{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "pytorch-unet-resnet18-colab.ipynb",
      "provenance": [],
      "authorship_tag": "ABX9TyMT+Pc8v3Z6njXdLUMBVh3M",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/usuyama/pytorch-unet/blob/master/pytorch_unet_resnet18_colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4lcY1ziTLblo",
        "colab_type": "text"
      },
      "source": [
        "## pytorch-unet\n",
        "\n",
        "https://github.com/usuyama/pytorch-unet"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yUvckFGU-4HE",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "8a844bf3-e9de-464f-ec07-f0c164d3794e"
      },
      "source": [
        "import os\n",
        "\n",
        "# pytorch_unet.py lives at the repository root, so its presence means the\n",
        "# working directory is already the checkout and nothing needs to be done.\n",
        "if not os.path.exists(\"pytorch_unet.py\"):\n",
        "  # Clone only when the checkout directory is missing, so re-running this\n",
        "  # cell does not attempt a second clone.\n",
        "  if not os.path.exists(\"pytorch-unet\"):\n",
        "    !git clone https://github.com/usuyama/pytorch-unet.git\n",
        "\n",
        "  %cd pytorch-unet"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/content/pytorch-unet\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EAx84Zg1_RnV",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 87
        },
        "outputId": "d93de9f3-4581-429a-f28a-af19ded2a8f0"
      },
      "source": [
        "# List the contents of the current working directory\n",
        "!ls"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "checkpoint.pth\tloss.py\t\t\t     pytorch-unet\t simulation.py\n",
            "helper.py\t__pycache__\t\t     pytorch_unet.ipynb\n",
            "images\t\tpytorch_fcn.ipynb\t     pytorch_unet.py\n",
            "LICENSE\t\tpytorch_resnet18_unet.ipynb  README.md\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N90-BlegJZfs",
        "colab_type": "text"
      },
      "source": [
        "## Enabling GPU on Colab\n",
        "\n",
        "A GPU runtime must be enabled in the notebook settings:\n",
        "\n",
        "- Open the Edit > Notebook settings menu\n",
        "- Select GPU from the Hardware accelerator dropdown list\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HCitpQdkJNdI",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "dd501c2c-e042-4826-e1c6-dc571b5f502e"
      },
      "source": [
        "import torch\n",
        "\n",
        "# Fail fast when no CUDA device is visible to PyTorch\n",
        "if not torch.cuda.is_available():\n",
        "  raise RuntimeError(\"GPU not available. CPU training will be too slow.\")\n",
        "\n",
        "print(\"device name\", torch.cuda.get_device_name(0))"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "device name Tesla P100-PCIE-16GB\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v8nZ6_mKMsJs",
        "colab_type": "text"
      },
      "source": [
        "## Synthetic images for demo training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6qt0VHVZ_53z",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 52
        },
        "outputId": "729453ea-5f84-4957-da1e-9fe1eca815e5"
      },
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import helper\n",
        "import simulation\n",
        "\n",
        "# Generate 3 random 192x192 samples (input images and their target masks)\n",
        "input_images, target_masks = simulation.generate_random_data(192, 192, count=3)\n",
        "\n",
        "# Sanity-check the array shapes and value ranges\n",
        "print(\"input_images shape and range\", input_images.shape, input_images.min(), input_images.max())\n",
        "print(\"target_masks shape and range\", target_masks.shape, target_masks.min(), target_masks.max())\n",
        "\n",
        "# Cast each image to uint8 for display with matplotlib (no channel reordering is done here)\n",
        "input_images_rgb = [x.astype(np.uint8) for x in input_images]\n",
        "\n",
        "# Map each channel (i.e. class) to each color\n",
        "target_masks_rgb = [helper.masks_to_colorimg(x) for x in target_masks]"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "input_images shape and range (3, 192, 192, 3) 0 255\n",
            "target_masks shape and range (3, 6, 192, 192) 0.0 1.0\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t16ni593BSUE",
        "colab_type": "text"
      },
      "source": [
        "# Left: Input image (black and white), Right: Target mask (6ch)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Dzjh6C1HBTCb",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 704
        },
        "outputId": "6dcecbc7-901b-4d34-8661-198b92ab1370"
      },
      "source": [
        "# Show the input images next to their colorized target masks\n",
        "helper.plot_side_by_side([input_images_rgb, target_masks_rgb])"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdsAAAKvCAYAAAAiIWV+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dfaxcd33n8c/nOoU/klQJZNYyTlwn1ICg2lxg5K1SEqUEugkCQkBKbVVgIOIabaJut5XaPFQFtaKilDTaim3gRrESr8CEbZKS7XoL2Qg1TRtKrsEYBwixU0exMfZNXEF4UFp7vvvHPRNObuZ6Hs75+TzM+yWNZuY358z5zrV/5zO/M+fBESEAAJDOTNUFAADQdoQtAACJEbYAACRG2AIAkBhhCwBAYoQtAACJJQtb25fbfsz2PtvXp1oOAAB15xTH2dpeJel7kt4q6aCkRyRtjohvl74wAABqLtXIdqOkfRHxRET8m6TPS7oy0bIAAKi10xK971pJT+WeH5T0n1aa2DanscI0ezoiOlUXUZZzzjkn1q9fX3UZQCV27do1sD+nCtuhbM9Jmqtq+UCNPFl1AUXl+/O6deu0sLBQcUVANWwP7M+pNiMfknRe7vm5WdvzImI+IroR0U1UA4BTJN+fO53WDNKB0qQK20ckbbB9vu2XSNok6b5EywIAoNaSbEaOiOO2r5P0JUmrJG2LiEdTLAsAgLpL9pttROyUtDPV+wMA0BScQQoAgMQIWwAAEiNsAQBIjLAFACAxwhYAgMQIWwAAEiNsAQBIjLAFACAxwhYAgMQIWwAAEiNsAQBIjLAFACAxwhYAgMQIWwAAEiNsAQBIjLAFACCxicPW9nm2v2L727Yftf1fs/aP2j5ke3d2e1t55QIA0DynFZj3uKTfi4iv2z5T0i7b92ev3RIRnyxeHgAAzTdx2EbEYUmHs8fP2v6OpLVlFQYAQFuU8put7fWSXi/pn7Om62zvsb3N9tllLAMAgKYqHLa2z5B0t6TfiYgfSbpV0islzWpp5HvzCvPN2V6wvVC0BgDVyvfnxcXFqssBascRMfnM9i9I+ltJX4qIvxjw+npJfxsRvzLkfSYvAmi+XRHRrbqIsnS73VhY4Ds0xjdKHtk+BZVMzvbA/jzxb7Ze+sS3S/pOPmhtr8l+z5WkqyTtnXQZwDDDOmdEaGaGI9yAOuv34633fH/otJ959ysk1T90lyuyN/KvSXqvpG/Z3p213Shps+1ZSSHpgKSthSoEBhh1i4xtRQShi9prw6huEhExUsj29af9zLtf0ai/R5G9kR+SNOiT7py8HIxjWkd1gz73oLZ8R7StXq/Xyr8Hmi8i9ONvrho63RkXnmhUwAwzbtDmbb3n+40K3CIjW1Ro1G/BbQuY5Z+7/3zQZ+z1epJ+Hrpt/Hug+UYNWkn68TdXtSZwTxa08+9ZO9K0TQpc1joNNM5Obf2AaYNBQTszM7NiePZfy8/Xpr8Hmm+coO378TdXjbUOqKOVwnP+PWtfFLTSUr+df8/a53+vzdt6z/cb8fcgbBtmkv9UbQiYlYJ2FAQu6miSoO1rcuCeLGiHsb1i4NYdYdsgBQ/Tak3ATPJb9PLABapUJGj7is5fJ6MEbd9KgVv3/k3YovbynajITl/5+dr05QNoikGj2nGCtm9Q4NZ9dEvYYqrU/dsvgHYibNEYZRzKxJ7IQH1MMqrtW2lzcl2x5gEAIDHCFgCAxAhbAAASI2wBAJWYph0WCdsGmab/mIOUcbjOtP8NgTopcrhOkfMqV4GwbZAiJ2Zo8kUJCEi0jW2dceGJQu9x5mw7jhOflv7dzLXvFJskcJsctFJ5J6Mo6+QYQBmKBG5Tg3alk1FMsk4r4+QYpxJrmwYaJ3DbEipFz208Ld+e0SyTBG5Tg/Zkxgncpm0+7mv+WnhK9QN32K0NQSsVu5hAkYsYAKmNE7htCNqTXU
ygv94apP/apBcxqFrh69naPiDpWUknJB2PiK7tl0m6S9J6SQckXR0R/1p0WXihaQuMmZkZ9Xq9F1yf9mRfKla6yPy0/d1Qf7ZbEaSj6gfu8uDsPx/nyj5NCFpJctHNa1nYdiPi6VzbJyQdi4iP275e0tkR8QcneQ+28WFk+cAdR42DdldEdKsuoizdbjcWFhaqLgMNUHSTcB2D1vbA/pxqzXOlpDuzx3dKelei5WAK5Tehj6Jtm9SBtihyfuM6Bu3JFN6MLCkkfTkbnX4mIuYlrY6Iw9nrP5C0uoTlAM/rB+cov9sSskB92db8e9aOPMptWsj2lRG2b4qIQ7b/g6T7bX83/2JExKDNxLbnJM2VsHxMMYK0HvL9ed26dRVXgybqh25bFV5TRcSh7P6opHslbZR0xPYaScrujw6Ybz4ium36rQqYVvn+3Ol0qi4HqJ1CYWv7dNtn9h9L+g1JeyXdJ2lLNtkWSV8sshwAAJqs6Gbk1ZLuzfYMPU3S5yLi72w/IukLtq+R9KSkqwsuBwCAxioUthHxhKQLB7Q/I+myIu8NAEBbsHcJAACJEbYAACRG2AIAkBhhCwBAYoQtAACJEbYAACRG2AIAkBhhCwBAYoQtAACJEbYAACRG2AIAkBhhCwBAYoQtAACJEbYAACRG2AIAkBhhCwBAYhNfPN72qyXdlWu6QNIfSTpL0ockLWbtN0bEzokrBACg4SYO24h4TNKsJNleJemQpHslfUDSLRHxyVIqBACg4crajHyZpP0R8WRJ7wcAQGuUFbabJO3IPb/O9h7b22yfXdIyAABopMJha/slkt4p6X9lTbdKeqWWNjEflnTzCvPN2V6wvVC0BgDVyvfnxcXF4TMAU6aMke0Vkr4eEUckKSKORMSJiOhJuk3SxkEzRcR8RHQjoltCDQAqlO/PnU6n6nKA2ikjbDcrtwnZ9prca1dJ2lvCMgAAaKyJ90aWJNunS3qrpK255k/YnpUUkg4sew0AgKlTKGwj4ieSXr6s7b2FKgIAoGU4gxQAAIkRtgAAJEbYAgCQGGELAEBihC0AAIkRtgAAJEbYAgCQGGELAEBihC0AAIkRtgAAJEbYAgCQGGELAEBihC0AAIkRtgAAJEbYAgCQGGELAEBiI4Wt7W22j9rem2t7me37bT+e3Z+dtdv2X9reZ3uP7TekKh4AgCYYdWR7h6TLl7VdL+mBiNgg6YHsuSRdIWlDdpuTdGvxMgEAaK6RwjYiHpR0bFnzlZLuzB7fKeldufbtseSrks6yvaaMYgEAaKIiv9mujojD2eMfSFqdPV4r6ancdAezNgAAplIpO0hFREiKceaxPWd7wfZCGTUAqE6+Py8uLlZdDlA7RcL2SH/zcHZ/NGs/JOm83HTnZm0vEBHzEdGNiG6BGgDUQL4/dzqdqssBaqdI2N4naUv2eIukL+ba35ftlfyrkn6Y29wMAMDUOW2UiWzvkHSppHNsH5T0EUkfl/QF29dIelLS1dnkOyW9TdI+ST+V9IGSawYAoFFGCtuI2LzCS5cNmDYkXVukKAAA2oQzSAEAkBhhCwBAYoQtAACJEbYAACRG2AIAkBhhCwBAYoQtAACJEbYAUFBEaOkUA8BghG1FIkK9Xq/qMgAUFBHa/tBF2v7QRVWXghojbAEASIywBQAgMcIWAIDECFsAABIjbAEASIywBQAgMcIWAIDEhoat7W22j9rem2v7c9vftb3H9r22z8ra19v+me3d2e3TKYsHAKAJRhnZ3iHp8mVt90v6lYj4j5K+J+mG3Gv7I2I2u324nDIBAGiuoWEbEQ9KOras7csRcTx7+lVJ5yaoDQAq0z8F47DbJPNwasfpc1oJ7/FBSXflnp9v+xuSfiTpDyPiH0pYRiOM24FsjzxPRGhmhp/YgVOhfwrGcd35D6PNY0nvu/ifRpvWHrsO1E+hsLV9k6Tjkj6bNR2WtC4inrH9Rkl/Y/t1EfGjAfPOSZorsvy6GSds8x0oIp4P3pXugT
rL9+d169ZVXE01IqSFv3nVyNN/+gN/OtJ0D++/adKSUCMTh63t90t6u6TLIkuDiHhO0nPZ412290t6laSF5fNHxLyk+ey9WpEm44w8+wHaH7H2er2h95JGmo4RME61fH/udruN78+2teXih0eaNj8K3vXFlcM2FDp00UGt/adzZTFanTYTrZVtXy7p9yW9MyJ+mmvv2F6VPb5A0gZJT5RRaBvlNw2vdG/7BQE66j2AGrJ06KKDCjX++wjGNMqhPzskPSzp1bYP2r5G0qcknSnp/mWH+FwiaY/t3ZL+WtKHI+LYwDfG80EqacX7UacbNB+AGiJwp9LQzcgRsXlA8+0rTHu3pLuLFjUtRhnZ9o07wgVQY1ngskl5erBmrtCoI9txpmdkCzQEI9ypQthWaJyRbd+oI1wADUDgTg3WzBUad2RbdD4ANUTgTgXCtkKTjGz7ho1wATQIgdt6rJkrVHSEerL5ATQMgVuKcU6ZeSpPoVnG6RoxoSIj276VRrgY3aSdjFNoonTspVxIROiRN7154vk3/uNXSqzmhVhTVKS/gi/jt9dBI1wADcUIt5UI24rMzMw8f+s/H3Q/quUjXACnRv/UjltGvLDAaG+6FLhoD9bMFStzr+L8CBdAw1k6+GtPcSGSliBsK1b28bIELtAeodDM/DsI3BYgbCuW4nhZAhdoFwK3+QjbiqU6ExSBC7QLgdtshG3FUp4JisAF2oXAbS7CtmKjjGzzF5of9T5/kDaBC9REjHcbdKztzPw7TlW1KBEntahY/3Cdle77J04YNt3yewD1Ylnn/tN5Y8/31PZPD58ItcdauWKjXp923HsAQH0MHdna3ibp7ZKORsSvZG0flfQhSYvZZDdGxM7stRskXSPphKTfjogvJagbAGrFth7ef1PVZaCmRhkG3SHp8gHtt0TEbHbrB+1rJW2S9Lpsnr+yvaqsYgEAaKKhYRsRD0o6NuL7XSnp8xHxXET8i6R9kjYWqA8AgMYr8gPfdbb32N5m++ysba2kp3LTHMzagNqq4+W4ALTLpHsj3yrpT7S0g/qfSLpZ0gfHeQPbc5LmJlw+UBp2Kisu35/XrVtXcTWYVraTXiaviInWMhFxJCJORERP0m36+abiQ5Ly+7afm7UNeo/5iOhGRHeSGgDUR74/dzqdqssBameisLW9Jvf0Kkl7s8f3Sdpk+6W2z5e0QdLXipUIAECzjXLozw5Jl0o6x/ZBSR+RdKntWS1tRj4gaaskRcSjtr8g6duSjku6NiJOpCkdAIBmGBq2EbF5QPPtJ5n+Y5I+VqQoAADahD1DAABIjLAFACAxwhYAgMQIWwAAEiNsAQBIjLAFACAxwhYAgMQIWwAAEiNsAQBIjLAFACAxwhYAgMQIWwAAEiNsAQBIjLAFACAxwhYAgMQIWwAAEhsatra32T5qe2+u7S7bu7PbAdu7s/b1tn+We+3TKYsHAKAJThthmjskfUrS9n5DRPxm/7HtmyX9MDf9/oiYLatAAACabmjYRsSDttcPes22JV0t6c3llgUAQHsU/c32YklHIuLxXNv5tr9h++9tX1zw/QEAaLxRNiOfzGZJO3LPD0taFxHP2H6jpL+x/bqI+NHyGW3PSZoruHwANZDvz+vWrau4GqB+Jh7Z2j5N0rsl3dVvi4jnIuKZ7PEuSfslvWrQ/BExHxHdiOhOWgOAesj3506nU3U5QO0U2Yz8FknfjYiD/QbbHdursscXSNog6YliJQIA0GyjHPqzQ9LDkl5t+6Dta7KXNumFm5Al6RJJe7JDgf5a0ocj4liZBQMA0DSj7I28eYX29w9ou1vS3cXLAgCgPTiDFAAAiRG2AAAkRtgCAJAYYQsAQGKELQAAiRG2AAAkRtgCAJAYYQsAQGKELQAAiRG2AAAkRtgCAJCYI6LqGmR7UdJPJD1ddS0lOEd8jjppwuf4pYhozXXpbD8r6bGq6yhBE/7vjILPcWoN7M+1CFtJsr3Qhmvb8jnqpS2fo0na8j
fnc9RL0z8Hm5EBAEiMsAUAILE6he181QWUhM9RL235HE3Slr85n6NeGv05avObLQAAbVWnkS0AAK1E2AIAkBhhCwBAYoQtAACJEbYAACRG2AIAkBhhCwBAYoQtAACJEbYAACRG2AIAkBhhCwBAYoQtAACJEbYAACRG2AIAkBhhCwBAYoQtAACJEbYAACRG2AIAkBhhCwBAYoQtAACJEbYAACRG2AIAkBhhCwBAYoQtAACJEbYAACRG2AIAkBhhCwBAYoQtAACJEbYAACSWLGxtX277Mdv7bF+fajkAANSdI6L8N7VXSfqepLdKOijpEUmbI+LbpS8MAICaSzWy3ShpX0Q8ERH/Junzkq5MtCwAAGotVdiulfRU7vnBrA0AgKlzWlULtj0naS57+saq6gBq4OmI6FRdRBH5/nz66ae/8TWveU3FFQHV2LVr18D+nCpsD0k6L/f83KzteRExL2lekmyX/8Mx0BxPVl1AUfn+3O12Y2FhoeKKgGrYHtifU21GfkTSBtvn236JpE2S7ku0LAAAai3JyDYijtu+TtKXJK2StC0iHk2xLAAA6i7Zb7YRsVPSzlTvDwBAU3AGKQAAEiNsAQBIjLAFgAEiQinOsIfpRNiWKCLU6/WqLgNAQRGh7Q9dpO0PXVR1KWgJwhYAgMQIWwAAEiNsAQBIjLAFACAxwhYAgMQIWwAAEiNsAQBIjLAFACAxwhYAgMSSXfUHAOpmktMvjjOP7bHfH9OBsB1i3M5pe+R5IkIzM2xcAE6F/ikYxzXOPFsufnjs98d0IGyHmORbLScvBwDkTRy2ts+TtF3SakkhaT4i/rvtj0r6kKTFbNIbswvJN9I4I8/+VUIYrQL1Y3vkkWd+FMxoFWUoMrI9Lun3IuLrts+UtMv2/dlrt0TEJ4uXBwBA800cthFxWNLh7PGztr8jaW1ZhQEA0BalbO+0vV7S6yX9c9Z0ne09trfZPruMZQAA0FSFw9b2GZLulvQ7EfEjSbdKeqWkWS2NfG9eYb452wu2F4rWAKBa+f68uLg4fAZgyhQKW9u/oKWg/WxE3CNJEXEkIk5ERE/SbZI2Dpo3IuYjohsR3SI1AKhevj93Op2qywFqZ+Kw9dJxLrdL+k5E/EWufU1usqsk7Z28PAAAmq/I3si/Jum9kr5le3fWdqOkzbZntXQ40AFJWwtVCABAwxXZG/khSYPOTdbYY2oBAEiBsy8AAJAYYVsiTtMIABiEcyOXiNM0Au0wzqkdgVGQDgAAJEbYAgCQGGELAEBihC0AAIkRtgAAJEbYAgCQGGELAEBihC0AAIkRtgAAJMYZpABgRKOcknXp6qPACxG2iQ3rnBHBaR6Bmuv34633fH/otJ959yskEbp4IcI2kVEvSmBbEUHoAjU0Tsj29acldJFH2JbsZCGbf215ByR0gXqJiLFCdrl86BK4KBy2tg9IelbSCUnHI6Jr+2WS7pK0XtIBSVdHxL8WXVbdDQrafFs+RHu93vOP8x3Rtnq9HoELVGiUoJ1/z9qRpt16z/cJXMhFr8GahW03Ip7OtX1C0rGI+Ljt6yWdHRF/cJL3aPyFYFcK2lFCs9frvagjMsKdKrsiolt1EWXpdruxsLBQdRkTO1l49gM2xbxoB9sD+3OqtfmVku7MHt8p6V2JllNb44TlzMwMF54Ham6UsLT9/G+1y9HHp1sZYRuSvmx7l+25rG11RBzOHv9A0uoSllNbyzvRJKPS5YHb35wM4NRZaWQ6zqh0pcAt8vsvmq+MsH1TRLxB0hWSrrV9Sf7FWEqQF32lsz1ne8F2c7c3DVBk8y8jXDRVvj8vLi5WXU6pJtn8u1Lg0r+nV+GwjYhD2f1RSfdK2ijpiO01kpTdHx0w33xEdJv+W1W+85TxO2t+fka3aIp8f+50OlWXM5FBo9oiv7MOClxGt9OrUDLYPt32mf3Hkn5D0l5J90nakk22RdIXiywHAIAmK3roz2pJ92Z70p4m6XMR8Xe2H5H0BdvXSH
pS0tUFlzNVIoLDBACgRQqFbUQ8IenCAe3PSLqsyHs3TZmH6vDbLVC9Mg7V6W9KZvMxOJATAIDECFsAGICtSygTYVuSMvccppMD1Stj02/R8yujPQhbAAASI2wLYhQKtBf9G2UhbAsq+yQUZZ8kA8BoVjoJxaSBW/ZJMtBsrMlLUNY5jfkWDdTPJIHLb7VYjrAtwfLR5ySBW8bFDAAUc7KLCIwauGVczADtw9q8JMs74jiBy4gWqL9+4PZvefl2RrQYpOjpGpGZmZl50UXgbT/fKZePVE8WsIxqgeqc7KxP+bb8CHhYwDKqBWFbon7gSnrRuY3zwbuS/usELVCtUU6zOOoIlqCFRNiWrh+UK4XuIIQsUD+2Nf+etRNvGiZkkUfYJrI8dEeZFkD9jBu6hCwGIWwTI0iBduiHLjAJkgAAgMQIWwAAEpt4M7LtV0u6K9d0gaQ/knSWpA9JWszab4yInRNXCABAw00cthHxmKRZSbK9StIhSfdK+oCkWyLik6VUCABAw5W1GfkySfsj4smS3g8AgNYoK2w3SdqRe36d7T22t9k+u6RlAADQSIXD1vZLJL1T0v/Kmm6V9EotbWI+LOnmFeabs71ge6FoDQCqle/Pi4uLw2cApkwZI9srJH09Io5IUkQciYgTEdGTdJukjYNmioj5iOhGRLeEGgBUKN+fO51O1eUAtVNG2G5WbhOy7TW5166StLeEZQAA0FiFziBl+3RJb5W0Ndf8CduzkkLSgWWvAQAwdQqFbUT8RNLLl7W9t1BFAAC0DGeQAgAgMcIWAIDECFsAABIjbAEASIywBQAgMcIWAIDECFsAABIjbAEASIywBQAgMcIWAIDECFsAABIjbAEASKzQhQgAAOWLiBVfs30KK0FZCNuWWKlzRoRmZtiAATRFROjH31y14utnXHiCwG0g1sItMOxbcK/XO4XVAJjUsKCVpB9/c9VJ+zzqibBtuFE6HYEL1N8oQdtH4DbPSGFre5vto7b35tpeZvt+249n92dn7bb9l7b32d5j+w2pip9243Q2Aheor3GCtm/c6VGtUUe2d0i6fFnb9ZIeiIgNkh7InkvSFZI2ZLc5SbcWLxMAgOYaKWwj4kFJx5Y1XynpzuzxnZLelWvfHku+Kuks22vKKBYAgCYq8pvt6og4nD3+gaTV2eO1kp7KTXcwawMAYCqVsoNULP14ONav9bbnbC/YXiijBgDVyffnxcXFqssBaqdI2B7pbx7O7o9m7YcknZeb7tys7QUiYj4iuhHRLVADgBrI9+dOp1N1OUDtFAnb+yRtyR5vkfTFXPv7sr2Sf1XSD3Obm1Eidv0HgGYY9dCfHZIelvRq2wdtXyPp45LeavtxSW/JnkvSTklPSNon6TZJ/6X0qiFJmpmZGTlwOZMUUF+2dcaFJ8aa58xZDuVrkpFO1xgRm1d46bIB04aka4sUhdHNzMyo1+ud9PRtBC1Qf/3AHeX4WYK2eVgDt0B/hLvSjaAFmmGUES5B20xciKAlCFSgHWwTqC3EGhoAgMQIWwAAEiNsAQBIjLAFACAxwhYAgMQIWwAAEiNsAQBIjLAFACAxwhYAgMQIWwAAEiNsAQBIjLAFACAxwhYAgMQIWwAAEhsatra32T5qe2+u7c9tf9f2Htv32j4ra19v+2e2d2e3T6csHgAwfU52/e5ht6qMcj3bOyR9StL2XNv9km6IiOO2/0zSDZL+IHttf0TMllolUNCknSwiuFYwUCMRoUfe9OaJ59/4j18psZrRDV2LRMSDko4ta/tyRBzPnn5V0rkJagMAoBXK+Mr+QUn/N/f8fNvfsP33ti8u4f0BAGi0UTYjr8j2TZKOS/ps1nRY0rqIeMb2GyX9je3XRcSPBsw7J2muyPIB1EO+P69bt67iaoD6mXhka/v9kt4u6bci+0EsIp6LiGeyx7sk7Zf0qkHzR8R8RHQjojtpDQDqId+fO51O1eUAtTNR2Nq+XNLvS3pnRPw0196xvSp7fIGkDZ
KeKKNQAEDzVLkHcJ2McujPDkkPS3q17YO2r9HS3slnSrp/2SE+l0jaY3u3pL+W9OGIODbwjYEV9Hq9qksAUIKI0Mz8OwhcjfCbbURsHtB8+wrT3i3p7qJFYbrZVq/X45AboCVm5t+h3tz/lu2qS6kMazPUUj9wAbTDtI9wCVvUFoELtMs0By5hi1ojcIF2mdbAJWxRewQu0C7TGLiELRqBwAXaZdoCl7BFYxC4QLtMU+AStmiUSQO3aZfjAqbFtARuoXMjA1WY5DhcjtkF6muc43BtV3aZvCJYA6GR2KQMtEvbR7iELRqLwAXapc2BS9ii0QhcoF3aGriELRqPwAXapY2BS9iiFQhcoF3aFriELVqDwAXapU2BS9iiVQhcoF3aEriELWpn0hNQ5E9EQeAC7TEz/46qSyhs6EktbG+T9HZJRyPiV7K2j0r6kKTFbLIbI2Jn9toNkq6RdELSb0fElxLUjRbjBBRAO9hWbP3bqsuohVHWandIunxA+y0RMZvd+kH7WkmbJL0um+evbK8qq1gAAJpoaNhGxIOSjo34fldK+nxEPBcR/yJpn6SNBeoDAKDximyvu872HtvbbJ+dta2V9FRumoNZGwAAU2vSsL1V0islzUo6LOnmcd/A9pztBdsLE9YAoCby/XlxcXH4DMCUmShsI+JIRJyIiJ6k2/TzTcWHJJ2Xm/TcrG3Qe8xHRDciupPUAKA+8v250+lUXQ5QOxOFre01uadXSdqbPb5P0ibbL7V9vqQNkr5WrEQAAJptlEN/dki6VNI5tg9K+oikS23PSgpJByRtlaSIeNT2FyR9W9JxSddGxIk0pQMA0AxDwzYiNg9ovv0k039M0seKFAUAQJtw9gAAABIjbAEASIywBQAgMcIWAIDECFsAABIbujcyAABVKOM6trZLqKQ4what0uv1NDMzU/geQLUiQhf98p8Wfp+H999UQjXFsVZBq/SDsug9AJSJNQtapdfrlXIPAGUibNEqjGwB1BFrFrQKI1sAdUTYolUY2QKoI9YsaBVGtgDqiLBFqzCyBVBHrFnQKoxsAdQRYYtWYWQLoI6Grllsb7N91PbeXNtdtndntwO2d2ft623/LPfap1MWDyzHyBZAHY1yusY7JH1K0vZ+Q0T8ZliH6AoAABSxSURBVP+x7Zsl/TA3/f6ImC2rQGAcjGwB1NHQNUtEPCjp2KDXvHSG56sl7Si5LmAijGwB1FHRr/EXSzoSEY/n2s63/Q3bf2/74oLvD4yFkS2AOiq6ZtmsF45qD0taFxGvl/S7kj5n+xcHzWh7zvaC7YWCNQDPY2RbjXx/XlxcrLocoHYmDlvbp0l6t6S7+m0R8VxEPJM93iVpv6RXDZo/IuYjohsR3UlrAJZjZFuNfH/udDpVlwPUTpE1y1skfTciDvYbbHdsr8oeXyBpg6QnipUIjI6RLYA6Gro3su0dki6VdI7tg5I+EhG3S9qkF+8YdYmkP7b975J6kj4cEQN3rgJSYGQLtIPt2lz4vQxDwzYiNq/Q/v4BbXdLurt4WQAAtAdf4wEASIywBQAgMcIWAIDECFsAABIjbAEASIywBQAgMcIWAIDECFsAABIjbAEASIywBQAgMcIWAIDECFsAABJzRFRdg2wvSvqJpKerrqUE54jPUSdN+By/FBGtuQis7WclPVZ1HSVowv+dUfA5Tq2B/bkWYStJthfacCF5Pke9tOVzNElb/uZ8jnpp+udgMzIAAIkRtgAAJFansJ2vuoCS8DnqpS2fo0na8jfnc9RLoz9HbX6zBQCgreo0sgUAoJUIWwAAEiNsAQBIjLAFACAxwhYAgMQIWwAAEiNsAQBIjLAFACAxwhYAgMQIWwAAEiNsAQBIjLAFACAxwhYAgMQIWwAAEiNsAQBIjLAFACAxwhYAgMQIWwAAEiNsAQBIjLAFACAxwhYAgMQIWwAAEiNsAQBIjLAFACAxwhYAgMQIWwAAEiNsAQBIjLAFACAxwhYAgMSSha3ty2
0/Znuf7etTLQcAgLpzRJT/pvYqSd+T9FZJByU9ImlzRHy79IUBAFBzqUa2GyXti4gnIuLfJH1e0pWJlgUAQK2dluh910p6Kvf8oKT/lJ/A9pykuezpGxPVATTB0xHRqbqIIvL9+fTTT3/ja17zmoorAqqxa9eugf05VdgOFRHzkuYlyXb527KB5niy6gKKyvfnbrcbCwsLFVcEVMP2wP6cajPyIUnn5Z6fm7UBADB1UoXtI5I22D7f9kskbZJ0X6JlAQBQa0k2I0fEcdvXSfqSpFWStkXEoymWBQBA3SX7zTYidkramer9AQBoCs4gBQBAYoQtAACJEbYAACRG2AIAkBhhCwBAYoQtAACJEbYAACRG2AIAkBhhCwBAYoQtAACJEbYAACRG2AIAkBhhCwBAYoQtAACJEbYAACRG2AIAkNjEYWv7PNtfsf1t24/a/q9Z+0dtH7K9O7u9rbxyAQBontMKzHtc0u9FxNdtnylpl+37s9duiYhPFi8PAIDmmzhsI+KwpMPZ42dtf0fS2rIKAwCgLUr5zdb2ekmvl/TPWdN1tvfY3mb77BXmmbO9YHuhjBoAVCffnxcXF6suB6idwmFr+wxJd0v6nYj4kaRbJb1S0qyWRr43D5ovIuYjohsR3aI1AKhWvj93Op2qywFqp1DY2v4FLQXtZyPiHkmKiCMRcSIiepJuk7SxeJkAADRXkb2RLel2Sd+JiL/Ita/JTXaVpL2TlwcAQPMV2Rv51yS9V9K3bO/O2m6UtNn2rKSQdEDS1kIVAgDQcEX2Rn5Ikge8tHPycgAAaB/OIAUAQGKELQAAiRG2AAAkRtgCAJAYYQsAQGJFDv0BgFMmIoZOs3T4P1A/rQ7bYZ0zIjQzw+AeqLuI0CNvevPQ6Tb+41dOQTXA+EgaAAASI2wBAEiMsAUAIDHCFgCAxAhbAAASa/XeyABwMqMcTrQSDjPCOAhbAFMpInTRL//pxPM/vP+mEqtB27V6M3JEnPQ2iV6vV+geADB9Wj2yTXHCiv57TnoPYHy2OWEFGq1wAtg+YPtbtnfbXsjaXmb7ftuPZ/dnFy+1HhjZAgDGVdZw69cjYjYiutnz6yU9EBEbJD2QPW8FRrYAgHGlSoArJd2ZPb5T0rsSLeeUY2QLABhXGWEbkr5se5ftuaxtdUQczh7/QNLq5TPZnrO90N/03BSMbIEXy/fnxcXFqssBaqeMBHhTRLxB0hWSrrV9Sf7FWNrt90W7/kbEfER0c5ueG4GRLfBi+f7c6XSqLgeoncJhGxGHsvujku6VtFHSEdtrJCm7P1p0OXXByBaoj2GH9xU5aQVQpkKH/tg+XdJMRDybPf4NSX8s6T5JWyR9PLv/YtFCy1L0Gre9Xk8zMzMT3wMoR0Ro6z3fHzrdZ979Cs72hMoVXfuvlvSQ7W9K+pqk/xMRf6elkH2r7cclvSV7XrlRvuXaPukmX0a2QLX6I9ZRglaStt7zfUa5qFyhkW1EPCHpwgHtz0i6rMh7l2l5JztZp7Mt2893zuUhycgWqM44IZvXn2f+PWvLLgkYSavPIDXIKJuJT7bJiZEtUC8rBeigYI4INimjEq1PgPwodljQSkuh2J9n0CZl9kYGqjEoPE82UrWtz7z7FS9o629SBk61VoftuEHbd7LAZWQLnHrjBm3fSoELnGpTsRl5nKDtywcugHoZ57fXfuAuD1nbXCYPpwzDLQAAEiNsAQBIjLAFACAxwhYAgMQIWwAAEiNsATQORwqgaVodtic7OcWo8wKon3FOTjHpKR6BMrU6bIedDWolk54MA0Aak54NatKTYQBla32K5INylMAlaIF6OlngrnQjaFEXU3MGKemFV/QZdR4A9cYmYjTBVIRtf3Q67Io+fYxogXqyrfn3rB37d1hGtKjaVCVK/zfcYTeCFqi3QZuUV0LQog4mHtnafrWku3JNF0j6I0lnSfqQpMWs/caI2DlxhSUjSIF26I9ygSaYOGwj4jFJs5Jke5WkQ5Lulf
QBSbdExCdLqRAAgIYr6zfbyyTtj4gnR/lNFABQL6PsFMr6fXJlhe0mSTtyz6+z/T5JC5J+LyL+dfkMtuckzZW0fFRglGMc2Ww/HfL9ed26dRVXg3H0+/GPv7lq6LRnXHhCEqE7CRc9xMX2SyR9X9LrIuKI7dWSnpYUkv5E0pqI+OCQ9+A4mwYZ9/8MoTvUrojoVl1EWbrdbiwsLFRdBkYQESOF7HJnXHiCwF2B7YH9uYw14BWSvh4RRyQpIo5ExImI6Em6TdLGEpaBmpjky9kkp8sEkNakQSstjYI5F8F4ygjbzcptQra9JvfaVZL2lrAM1ECRzkXgAvVRJGj7CNzxFApb26dLequke3LNn7D9Ldt7JP26pP9WZBmohzI6FYELVK+MoO0r632mQaEdpCLiJ5JevqztvYUqAgA0RkTw++0I2GsFQ5W5qYjRLVCdMke1fYxuR0PYAgCQGGELAEBihC0AAIkRtgAAJEbYAgCQGGELAEBihC0AAIkRthiKU7IBQDGELYYq84o9XAEIqI7t5y+TV5YzZzlJzShY62EkjG4BYHKELUYyMzNTOHAZ1QLVK3N0y6h2dKz5MLIigUvQAvVRRuAStONh7Yex9AN31NDtT0vQAvVSJHAJ2vEVusQeplM/OEe5eg8hC9SXbZ052xv5akCE7OQIW0yMIAXaoR+6SGektaXtbbaP2t6ba3uZ7fttP57dn5212/Zf2t5ne4/tN6QqHgCAJhh1aHKHpMuXtV0v6YGI2CDpgey5JF0haUN2m5N0a/EyAQBorpHCNiIelHRsWfOVku7MHt8p6V259u2x5KuSzrK9poxiAQBooiI/uq2OiMPZ4x9IWp09Xivpqdx0B7O2F7A9Z3vB9kKBGgDUQL4/Ly4uVl0OUDul7OESS8eBjHUAZkTMR0Q3Irpl1ACgOvn+3Ol0qi4HqJ0iYXukv3k4uz+atR+SdF5uunOzNgAAplKRsL1P0pbs8RZJX8y1vy/bK/lXJf0wt7kZAICpM9JxtrZ3SLpU0jm2D0r6iKSPS/qC7WskPSnp6mzynZLeJmmfpJ9K+kDJNQMA0CgjhW1EbF7hpcsGTBuSri1SFAAAbcIpgAAASIywBQAgMcIWAIDECFsAABIjbAEAIxvnetb4OcIWyUXESNe+BVBvEaHtD12k7Q9dVHUpjUPYAgCQGGELAEBihC0AAIkRtgAAJEbYAgCQGGELAEBihC0AAIkRtgAAJEbYAgCQGGELAEBiQ8PW9jbbR23vzbX9ue3v2t5j+17bZ2Xt623/zPbu7PbplMUDAIrpn+t4lNuk83EuZem0Eaa5Q9KnJG3Ptd0v6YaIOG77zyTdIOkPstf2R8RsqVWiVibpOLbHmi8iNDPDhhcgpf65jicx7nxbLn54ouW0xdC1WUQ8KOnYsrYvR8Tx7OlXJZ2boDbU1KTfaPkWDGBajTKyHeaDku7KPT/f9jck/UjSH0bEPwyayfacpLkSlo9TbNwRZz9AGam2V74/r1u3ruJqMCrbY4048yPhaR+pjqvQ2s/2TZKOS/ps1nRY0rqIeL2k35X0Odu/OGjeiJiPiG5EdIvUAKB6+f7c6XSqLgeonYnD1vb7Jb1d0m9Ftt0vIp6LiGeyx7sk7Zf0qhLqBACgsSYKW9uXS/p9Se+MiJ/m2ju2V2WPL5C0QdITZRQKAEBTDf3N1vYOSZdKOsf2QUkf0dLexy+VdL9tSfpqRHxY0iWS/tj2v0vqSfpwRBwb+MYA0AARoWw9B0xsaNhGxOYBzbevMO3dku4uWlTT9Xo9dgYCWiAiNDP/DvXm/jeBi0JIhARsq9frVV0GgJLMzL+DQ9JQCGGbCIELtAuBiyII24QIXKBdCFxMirBNjMAF2oXAxSQI21OAwAXahcDFuAjbU2SaA5eVEtqIwMU4yjg3MkbUD9xpOyxo2j4vpse0HRY07rmU8XOsBU+xaR7hAm3ECBejIGwrQOAC7ULgYhjCNoFRr9dK4ALtMTP/jqpLQI
3xm20C/EYJtINtxda/rboMtACpAABAYoQtAACJEbYAACRG2AIAkBhhCwBAYkPD1vY220dt7821fdT2Idu7s9vbcq/dYHuf7cds/+dUhQMA0BSjjGzvkHT5gPZbImI2u+2UJNuvlbRJ0uuyef7K9qqyigUAoImGhm1EPCjp2Ijvd6Wkz0fEcxHxL5L2SdpYoD4AABqvyG+219nek21mPjtrWyvpqdw0B7O2F7E9Z3vB9kKBGgDUQL4/Ly4uVl0OUDuThu2tkl4paVbSYUk3j/sGETEfEd2I6E5YA4CayPfnTqdTdTlA7UwUthFxJCJORERP0m36+abiQ5LOy016btYGAMDUmihsba/JPb1KUn9P5fskbbL9UtvnS9og6WvFSgQAoNmGXojA9g5Jl0o6x/ZBSR+RdKntWUkh6YCkrZIUEY/a/oKkb0s6LunaiDiRpnQAAJphaNhGxOYBzbefZPqPSfpYkaIAAGgTziAFAEBihC0AAIkRtgAAJEbYAgCQGGELAEBihC0AAIkRtgAAJEbYAgCQGGELAEBihC0AAIkRtgAAJEbYAgCQGGELAEBihC0AAIkRtgAAJEbYAgCQ2NCwtb3N9lHbe3Ntd9nend0O2N6dta+3/bPca59OWTwAAE1w2gjT3CHpU5K29xsi4jf7j23fLOmHuen3R8RsWQUCANB0Q8M2Ih60vX7Qa7Yt6WpJby63LAAA2qPob7YXSzoSEY/n2s63/Q3bf2/74pVmtD1ne8H2QsEaAFQs358XFxerLgeonaJhu1nSjtzzw5LWRcTrJf2upM/Z/sVBM0bEfER0I6JbsAYAFcv3506nU3U5QO1MHLa2T5P0bkl39dsi4rmIeCZ7vEvSfkmvKlokAABNVmRk+xZJ342Ig/0G2x3bq7LHF0jaIOmJYiUCANBsoxz6s0PSw5Jebfug7WuylzbphZuQJekSSXuyQ4H+WtKHI+JYmQUDANA0o+yNvHmF9vcPaLtb0t3FywIAoD04gxQAAIkRtgAAJEbYAgCQGGELAEBihC0AAIkRtgAAJEbYAgCQGGELAEBihC0AAIkRtgAAJEbYAgCQmCOi6hpke1HSTyQ9XXUtJThHfI46acLn+KWIaM1FYG0/K+mxqusoQRP+74yCz3FqDezPtQhbSbK90IYLyfM56qUtn6NJ2vI353PUS9M/B5uRAQBIjLAFACCxOoXtfNUFlITPUS9t+RxN0pa/OZ+jXhr9OWrzmy0AAG1Vp5EtAACtVHnY2r7c9mO299m+vup6xmH7gO1v2d5teyFre5nt+20/nt2fXXWdy9neZvuo7b25toF1e8lfZv8+e2y/obrKX2iFz/FR24eyf5Pdtt+We+2G7HM8Zvs/V1N1u9GfTz36czP6c6Vha3uVpP8h6QpJr5W02fZrq6xpAr8eEbO5XdKvl/RARGyQ9ED2vG7ukHT5sraV6r5C0obsNifp1lNU4yju0Is/hyTdkv2bzEbETknK/l9tkvS6bJ6/yv7/oST058rcIfpz7ftz1SPbjZL2RcQTEfFvkj4v6cqKayrqSkl3Zo/vlPSuCmsZKCIelHRsWfNKdV8paXss+aqks2yvOTWVntwKn2MlV0r6fEQ8FxH/Immflv7/oTz05wrQn5vRn6sO27WSnso9P5i1NUVI+rLtXbbnsrbVEXE4e/wDSaurKW1sK9XdxH+j67JNZNtym/2a+Dmapul/Y/pzPbWiP1cdtk33poh4g5Y2zVxr+5L8i7G0q3fjdvduat2ZWyW9UtKspMOSbq62HDQI/bl+WtOfqw7bQ5LOyz0/N2trhIg4lN0flXSvljZjHOlvlsnuj1ZX4VhWqrtR/0YRcSQiTkRET9Jt+vmmpUZ9joZq9N+Y/lw/berPVYftI5I22D7f9ku09IP3fRXXNBLbp9s+s/9Y0m9I2qul+rdkk22R9MVqKhzbSnXfJ+l92V6Mvyrph7nNU7Wz7Penq7T0byItfY5Ntl9q+3wt7SDytVNdX8vRn+
uD/lw3EVHpTdLbJH1P0n5JN1Vdzxh1XyDpm9nt0X7tkl6upb3/Hpf0/yS9rOpaB9S+Q0ubZP5dS791XLNS3ZKspT1M90v6lqRu1fUP+Rz/M6tzj5Y65Jrc9Ddln+MxSVdUXX8bb/TnSmqnPzegP3MGKQAAEqt6MzIAAK1H2AIAkBhhCwBAYoQtAACJEbYAACRG2AIAkBhhCwBAYoQtAACJ/X9ckXs/r3devQAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "<Figure size 576x864 with 6 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Qg2FqLRGBEJT",
        "colab_type": "text"
      },
      "source": [
        "## Prepare Dataset and DataLoader"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_-UTr03eAROb",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from torch.utils.data import Dataset, DataLoader\n",
        "from torchvision import transforms, datasets, models\n",
        "\n",
        "class SimDataset(Dataset):\n",
        "  \"\"\"Dataset of synthetic images and masks from simulation.generate_random_data.\n",
        "\n",
        "  All `count` samples are generated once in the constructor and kept on the instance.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, count, transform=None):\n",
        "    # `transform`, when given, is applied to the image only (see __getitem__)\n",
        "    generated = simulation.generate_random_data(192, 192, count=count)\n",
        "    self.input_images, self.target_masks = generated\n",
        "    self.transform = transform\n",
        "\n",
        "  def __len__(self):\n",
        "    \"\"\"Return the number of generated samples.\"\"\"\n",
        "    return len(self.input_images)\n",
        "\n",
        "  def __getitem__(self, idx):\n",
        "    \"\"\"Return [image, mask] for sample `idx`; the image is transformed if a transform is set.\"\"\"\n",
        "    sample, mask = self.input_images[idx], self.target_masks[idx]\n",
        "    if self.transform:\n",
        "      sample = self.transform(sample)\n",
        "\n",
        "    return [sample, mask]\n",
        "\n",
        "# One preprocessing pipeline is shared by train and val in this example:\n",
        "# convert to a tensor, then normalize with the ImageNet statistics.\n",
        "trans = transforms.Compose([\n",
        "  transforms.ToTensor(),\n",
        "  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
        "])\n",
        "\n",
        "train_set = SimDataset(2000, transform=trans)\n",
        "val_set = SimDataset(200, transform=trans)\n",
        "\n",
        "image_datasets = dict(train=train_set, val=val_set)\n",
        "\n",
        "batch_size = 25\n",
        "\n",
        "# Both splits use identical loader settings (shuffled, no worker processes)\n",
        "dataloaders = {\n",
        "  split: DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)\n",
        "  for split, dataset in image_datasets.items()\n",
        "}"
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BtkJTyxGB-XB",
        "colab_type": "text"
      },
      "source": [
        "## Check the outputs from DataLoader"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CRIOwoQvBKPm",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 304
        },
        "outputId": "b8f5a8ef-a44e-4082-fae3-e7de4cbe7c25"
      },
      "source": [
        "import torchvision.utils\n",
        "\n",
        "def reverse_transform(inp):\n",
        "  \"\"\"Convert a normalized (C, H, W) image tensor to an (H, W, C) uint8 array for plotting.\n",
        "\n",
        "  Undoes the mean/std normalization with the constants below, clips the\n",
        "  result to [0, 1] and rescales it to 0-255.\n",
        "  \"\"\"\n",
        "  channel_mean = np.array([0.485, 0.456, 0.406])\n",
        "  channel_std = np.array([0.229, 0.224, 0.225])\n",
        "  # Move channels last so the per-channel constants broadcast over the final axis\n",
        "  hwc = np.transpose(inp.numpy(), (1, 2, 0))\n",
        "  restored = np.clip(channel_std * hwc + channel_mean, 0, 1)\n",
        "\n",
        "  return (restored * 255).astype(np.uint8)\n",
        "\n",
        "# Get a batch of training data (first batch from a fresh iterator over the shuffled loader)\n",
        "inputs, masks = next(iter(dataloaders['train']))\n",
        "\n",
        "# Print the batch tensor shapes for the images and the masks\n",
        "print(inputs.shape, masks.shape)\n",
        "\n",
        "# Undo the normalization and display the sample at index 3 of the batch\n",
        "plt.imshow(reverse_transform(inputs[3]))"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "torch.Size([25, 3, 192, 192]) torch.Size([25, 6, 192, 192])\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.image.AxesImage at 0x7fa7176db978>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 7
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAD8CAYAAAB3lxGOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAWFklEQVR4nO3de3RU5bnH8e+ThEkIRFAJgQIe0HIxiCCl6tHSYrWKtprisR7tOS2oFXuU1VZbL/WG1WVX1Xps6x2PCNoC1WO9rC7v2NZ6EBFQUbQgEFAoQg3gLTEE8pw/ZmOHZGJCZiY7mff3Wetd2fPO7D3PNuHnvr/m7ohIuAriLkBE4qUQEAmcQkAkcAoBkcApBEQCpxAQCVzOQsDMJprZCjNbZWaX5Op7RCQzlovrBMysEFgJfA1YD7wEnO7ub2T9y0QkI7naEjgUWOXua9x9OzAPqMrRd4lIBopytNwBwDspr9cDh7X0YTPTZYsiufeeu5c37cxVCLTKzKYCU+P6fpEArUvXmasQ2AAMSnk9MOr7lLvPAGaAtgRE4pSrYwIvAUPNbIiZJYDTgEdz9F0ikoGcbAm4+w4zmwY8CRQCM919eS6+S0Qyk5NThHtchHYHRDrCEncf17RTVwyKBE4hIBI4hYBI4BQCIoFTCIgETiEgEjiFgEjgFAIigVMIiAROISASOIWASOAUAiKBUwiIBE4hIBI4hYBI4BQCIoFTCIgETiEg7da3b1+Kioro1asXZWVlJBIJysuTT7Tu168fAH369KG4uJiePXvSq3dvioqK6Nu3b5xlSxMKAWm3V5ctY9SoUdx0001cfvnljB8/nv9bsIDi4mLWvf02paWlPPfXv3L00Udz8cUXc9tttzFixAiWL9fjJjsTPWNQJBzZfcagmQ0ysz+Z2RtmttzMfhj1X2VmG8zslaidkEnV0nmtXrOG0aNHc/vtt3PNNdcwYcIEXn31VRLFxWzZupXu3buzZMkSjjvuOK644gpm3nMPlZWVrF2XdgwMiYu7t6sB/YGx0XQZyQFIK4GrgJ/s4bJcreu1AysrvaSkxAcOHOj9+vXznj17+vDhw93MfNSoUW5mPmzYMC8rK/OKigoftN9+XlJS4pWVlbHXHmhbnO7fX7u3BNx9o7svjaY/BN4kOQahBOKCCy6gf//+nFRVxTHHHMOQIUOYNm0ahYWFXHbZZXTr1o1zzz2XoUOHctRRR3HypEmU9+3LTy68MO7SJUVWDgya2WDgEODFqGuamS0zs5lmtncL80w1s8VmtjgbNUjHW1tdTX19PZs3baKmpobaujrWrVuHu7N69WrcnbfffpuPP/6Ymi1b2LRpE9vr66lesybu0iVFxiFgZj2BB4EfufsHwO3AAcAYYCNwY7r53H2Gu49Ld6BCuoZFL73ERx99xFurVrF27Vre37aNpUuX4u4sWLCAxsZGlr78Mlu3buXtdetYuXIltbW1LFq0KO7SJVV7jwlE+/LdSA41dkEL7w8GXtcxgfxsLy5a5JUjR/r111/vF19yiR9x5JH+7J/+5IlEwlesXOklJSX+1NNP+4SjjvILLrjAf/WrX/mwYcN8yZIlsdceaEt7TKDdpwjNzIDZwBZ3/1FKf3933xhNnw8c5u6ntbKs9hUhserevTv19fUUFRXh7uzcuZNEcTGf1NVRWlpKbW0tJSUlNDQ0UFBQgJnR0NBASUkJdXV1cZcforSnCDPZCvgSyXRZBrwStROA+4DXov5Hgf7aEsjP9o/33vOxY8f6rNmz/YYbbvBjjjnG16xZ48XFxb6zsdFLS0t9xcqV/vWvf92vvfZanzdvno8aNcq3bN0ae+2BtuxuCWSTtgS6pkQiQUNDA4WFhQDsbGykW1ER27dvp7i4mPr6ehKJBDt27MDMMDN27txJt0SC7fX1MVcfJA1IKtn1wsKFVFZW8ovrruPCiy7iiCOO4Jn580
kkErz55puUdO/OE08+yYQJEzj//PO56aabGDZsGC/pwGCnoi0BabevfOUrLF26lP3224/t27dTU1PDgZWVLFiwgKO/+lWeffZZDv/Xf2XlihX06t2b0u7dqa6uZty4cfz5z3+Ou/ysKCgo4LHHHyd5iKy5+vp6TjrxxA6uqkXZPSaQzUb8+0pqanvcevbs6b+88Ubf2djoje5pW/327X7DDTd4UVFR7PXSwjGB2ANAIaDWFVufPn38vPPOa/Eff9N25plnes+ePeOuWwcGRbLl8MMPZ8ELL+zRPAcccEDcV0vqwKCINKcQEAmcQkAkcAoBkcApBEQCpxAQCZxCQKQd3n33Xe666642f/7WW2/lg/ffz2FFGYj7QiFdLKTWVVvv3r19/vz5n3nFYMOOHf7MM894IpGIvV50sVDX8bkBA/jH5s00NDTEXYq0orCwkFWrV1NQkH6j+pO6OoYPH97BVbUo7cVCCoFOqGbLFo477jiWLllCY2Nj3OVI/tAVg13Jiy++yKmnnhp3GRIAhUAnZWbMvOcerr766rhLkTynEOjESkpK+K9zz+XOO++MuxTJYwqBTm7fffflpKoq7rjjjrhLkTyVjXEH1prZa9G4g4ujvn3M7Gkzeyv6mXYAEmmbiooKvnXqqVx2+eVxlyL5KAvn+NcCfZr0XQ9cEk1fAlyn6wTa3mq2bEl7zrnuk0980qRJXlBQEHuNal2yZXcswlZUkRyTgOjnN3P0PUEpLi7mwT/8gYNHj6Zbt25xlyN5Ihsh4MBTZrbEzKZGfRUeDUACvAtUZOF7JLJ06VJGjBjx6aO+RTKRjRD4kruPBY4HzjOzL6e+6cntfW86kwYkzcyry5Zx7LHHxl2G5IGMQ8DdN0Q/NwMPAYcCm8ysPySHJQM2p5lPA5Jm6JFHH+XHP/5x3GVIF5dRCJhZDzMr2zUNHAu8TnL4scnRxyYDj2TyPZJeUVERl19xBTffckvcpUgXVpTh/BXAQ9HAC0XAHHd/wsxeAu43s7OAdYCuf82RXr168a1TTqGkuJizzz477nKkC8ooBNx9DTA6TX8NcHQmy5a261tRwUlVVWzZupWLL7oo7nKki9EVg3mivLycc889lzPPOqvFIbFE0lEI5JEePXowY8YMjjjiCIqKMt3Tk1AoBPJMQUEBf33+eYYNG6YgkDZRCOSp15cv57DDDtOugbRKIZDH/vLcc5xx5plxlyGdnEIgjxUUFPDrX/+a666/Pu5SpBNTCOS5Hj16MGXKlD16PLaERUeOOqFzpk4lkUhkdZmbNje7clsE0NOGRUKipw2LSHMKAZHAKQREAqcQEAmcQkAkcAoBkcApBEQCpxAQCZxCQCRwCgGRwLX73gEzGw78PqVrf+BKoDdwNvCPqP9Sd3+s3RWKSE5l5d4BMysENgCHAWcAH7n7L/dgft07IJJ7Ob134Ghgtbuvy9LyRKSDZCsETgPmpryeZmbLzGymhiUX6dwyDgEzSwAnAQ9EXbcDBwBjgI3AjS3Mp7EIRTqBjI8JmFkVcJ67Nxsd08wGA39094NaWYaOCYjkXs6OCZxOyq7AroFII5NIjk0oIp1URo8XiwYh/RpwTkr39WY2huRw5GubvCcinYweLyYSDj1eTESaUwiIBE4hIBI4hYBI4BQCIoFTCIgETiEgEjiFgEjgFAIigVMIiAROISASOIWASOAUAiKBUwiIBE4hIBI4hYBI4BQCIoFTCIgETiEgErg2hUA0iMhmM3s9pW8fM3vazN6Kfu4d9ZuZ/cbMVkUDkIzNVfEikrm2bgnMAiY26bsEmO/uQ4H50WuA44GhUZtKcjASEemk2hQC7v4csKVJdxUwO5qeDXwzpf9eT1oI9G4yFoGIdCKZHBOocPeN0fS7QEU0PQB4J+Vz66M+EemEMhp8ZBd39z0dO8DMppLcXRCRGGWyJbBp12Z+9HNz1L8BGJTyuYFR327cfYa7j0s3GIKIdJxMQuBRYHI0PRl4JKX/u9FZgs
OB91N2G0Sks3H3VhvJAUc3Ag0k9/HPAvYleVbgLeAZYJ/oswbcCqwGXgPGtWH5rqamlvO2ON2/P41FKBIOjUUoIs0pBEQCpxAQCZxCQCRwCgGRwCkERAKnEBAJnEJAJHAKAZHAKQREAqcQEAmcQkAkcAoBkcApBEQCpxAQCZxCQCRwCgGRwCkERAKnEJAuKZFIMG/ePObNm0d5eXnc5XRpesagdEmlpaV89PHHAAwZPJh169bFXFGX0L5nDLYwGOkNZva3aMDRh8ysd9Q/2MzqzOyVqN2R3XUQkWxry+7ALJoPRvo0cJC7HwysBH6a8t5qdx8Tte9np0wRyZVWQyDdYKTu/pS774heLiQ5ypCIdEHZODB4JvB4yushZvaymf3FzMa3NJOZTTWzxWa2OAs1iEg7ZTQgqZldBuwAfhd1bQT2c/caM/sC8LCZjXT3D5rO6+4zgBnRcnRgUCQm7d4SMLMpwDeA//BdY4m517t7TTS9hORQZMOyUKeI5Ei7QsDMJgIXASe5e21Kf7mZFUbT+wNDgTXZKFREcqPV3QEzmwtMAPqY2XpgOsmzAcXA02YGsDA6E/Bl4GozawAage+7+5a0CxaRTqHVEHD309N0393CZx8EHsy0KAlXeXk5Eyc2PSPdXKK4+NPpk08+mffee6/VeRYsWMDq1aszqi8vtWVo8lw34h+yWa2TtCOPPNIb3XPSpkyZEvv6xdzSDk2e0dkBkWxr2LGDmpqaVj9nZuyzzz4AbNu2jZ07d7Y6T319fcb15SPdOyBdku4daJf23TsgIvlNISASOIWASOAUAiKBUwiIBE4hIBI4hYBI4BQCIoFTCIgELqjLhtdv2MBee+3V5s/PnTuXc6ZOzWFF8N3Jk7nlllta/dza6moOPvjgnNbSldTW1lLWs+en09J+QYVAj9JSekZ/OG1RnHKnWrbde999jBs3jl69erWppmHDh/PGm28CcMiYMboOHvg4umxYMqPdgRjcfffdTJw4kREjRtC/f/82zZNIJBgxYgQjRozggQceoKKiIsdVSiiC2hLoDK688kr+7ZRTdtsteeedd/ifu+5qcZ6hQ4fyn9/5zqevv3HiiVx66aXcfPPNrFq1Kqf1SgDifpZARz5PYOvWrXt0//ms2bOz9t1FRUU+adIk397QsNt3rKmu9p///OefOe8hhxzijz/xRLP6rr7mGh86dGjc96irdZ2W9nkCsQdAKCFQVlbWbPkb/v53/9nPftam+QcOHOjL33ij2TIuvOiiuP+w1LpOSxsCOibQAQoKCpod/KutreWq6dOZPn16m5axfv16vjB2LB9++OGu4ASSBy9LSkqyWq8Epg3/l54JbAZeT+m7CtgAvBK1E1Le+ymwClgBHKctAfzAAw/cbbk7Gxt9/Pjx7V7ex7W1vrOx8dPl/W7OnLj/D6PWNVq7twRm0XwsQoCb/J9jDj4GYGaVwGnAyGie23Y9glz+af8hQ3j++efbPX/vXr2orq7OYkUSsnaNRfgZqoB5nhyEpJrkFsGhGdSXlxoaGnbbpG/P/HSCx8JJfsjkmMC0aGjymWa2d9Q3AHgn5TProz4R6aTaGwK3AwcAY0iOP3jjni5AA5KKdA7tuljI3Tftmjazu4A/Ri83AINSPjow6ku3jA4fkPTe++6je/fubf78whdeyGE1Ip1Du0LAzPq7+8bo5STg9Wj6UWCOmf038DmSYxEuyrjKLPnhD34QdwkinU57xyKcYGZjSJ52WAucA+Duy83sfuANkkOWn+furY8KISKxyepYhNHnrwWuzaSofFdYlNktG4WFhZAcCFYkY7piMAbV1dWMHz++3fO//8EHDBkyJIsVSdDivm+gI68YjKsVFBT4oEGDdrtq8P0PPvCzvve9PVpOSUmJ12zZstvVgldOn+6lpaWxr6Nal2i6dyAujY2NbNu2bbe+srIypk+fzhVXXtmmZQwYMICFCxey9957Yym7AnV1dXqyjmQm7q2AELYEIHkr8b
e//e1mtxKvfOstnz59+mfOe/Do0f7Qww83u7fhuuuu8wMPPDD2dVPrMk1Dk8dpx44dzJkzh5EHHcS0adMoKysD4POf/zxTzjiD2rq6FucdPnw4VVVVu/XNuPNOZs6cyYoVK3JatwQg7q2AULYEUtt9v/2tv1dTs0d3NKa2J5980vv16xf7eqh1uZZ2S8C8E9yI0lFXDHYmc+bO5dAvfpGyvfaivLy81c83NDTwzttvAzBy5Eg9aFTaY4m7j2vaqRCI2eQpU7jtttta/Vx1dTUHjRzZARVJHlMIiAQubQjoFKFI4BQCIoFTCIgETiEgEjiFgEjgFAIigVMIiAROISASOIWASOAUAiKBazUEosFFNpvZ6yl9vzezV6K21sxeifoHm1ldynt35LJ4EclcW54nMAu4Bbh3V4e7//uuaTO7EXg/5fOr3X1MtgoUkdxqy9OGnzOzwenes+Rzrk4FvprdskSko2R6TGA8sMnd30rpG2JmL5vZX8ys/Y/UFZEOkenjxU4H5qa83gjs5+41ZvYF4GEzG+nuHzSd0cymAlMz/H4RyVC7twTMrAg4Gfj9rj5PDkleE00vAVYDw9LN7+4z3H1cuvubRaTjZLI7cAzwN3dfv6vDzMrNrDCa3p/kWIRrMitRRHKpLacI5wIvAMPNbL2ZnRW9dRq77woAfBlYFp0y/F/g++6+JZsFi0h26fFiIuHQ48VEpDmFgEjgFAIigVMIiAROISASOIWASOAUAiKBUwiIBE4hIBI4hYBI4BQCIoFTCIgETiEgEjiFgEjgMn28WLa8B3wc/cxnfcjvdcz39YOuvY7/kq6zUzxPAMDMFuf7o8byfR3zff0gP9dRuwMigVMIiASuM4XAjLgL6AD5vo75vn6Qh+vYaY4JiEg8OtOWgIjEIPYQMLOJZrbCzFaZ2SVx15Mt0WjNr0WjMy+O+vYxs6fN7K3o595x17knWhihOu06WdJvot/rMjMbG1/lbdPC+l1lZhtSRto+IeW9n0brt8LMjoun6szFGgLRQCW3AscDlcDpZlYZZ01ZdpS7j0k5pXQJMN/dhwLzo9ddySxgYpO+ltbpeJKDzwwlOdzc7R1UYyZm0Xz9AG6Kfo9j3P0xgOjv9DRgZDTPbbsG3ulq4t4SOBRY5e5r3H07MA+oirmmXKoCZkfTs4FvxljLHnP354Cmg8m0tE5VwL2etBDobWb9O6bS9mlh/VpSBcyLht6rBlaR/HvucuIOgQHAOymv10d9+cCBp8xsSTT4KkCFu2+Mpt8FKuIpLataWqd8+t1Oi3ZpZqbswuXN+sUdAvnsS+4+luRm8Xlm9uXUNz15WiavTs3k4zqR3I05ABhDctTtG+MtJ/viDoENwKCU1wOjvi7P3TdEPzcDD5HcVNy0a5M4+rk5vgqzpqV1yovfrbtvcved7t4I3MU/N/nzYv0g/hB4CRhqZkPMLEHyQMujMdeUMTPrYWZlu6aBY4HXSa7b5Ohjk4FH4qkwq1pap0eB70ZnCQ4H3k/ZbegymhzHmETy9wjJ9TvNzIrNbAjJA6CLOrq+bIj1LkJ332Fm04AngUJgprsvj7OmLKkAHjIzSP43nuPuT5jZS8D90cjO64BTY6xxj0UjVE8A+pjZemA68AvSr9NjwAkkD5jVAmd0eMF7qIX1m2BmY0ju5qwFzgFw9+Vmdj/wBrADOM/dd8ZRd6Z0xaBI4OLeHRCRmCkERAKnEBAJnEJAJHAKAZHAKQREAqcQEAmcQkAkcP8PWIfaLw0qe5EAAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E7XRZIKtCN8E",
        "colab_type": "text"
      },
      "source": [
        "# Define a UNet module"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "b8EJl0hcC5DH",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch.nn as nn\n",
        "import torchvision.models\n",
        "\n",
        "\n",
        "def convrelu(in_channels, out_channels, kernel, padding):\n",
        "  \"\"\"Return a two-stage block: a 2-D convolution followed by an in-place ReLU.\n",
        "\n",
        "  Args:\n",
        "    in_channels: number of channels of the incoming feature map.\n",
        "    out_channels: number of channels the convolution produces.\n",
        "    kernel: convolution kernel size, forwarded to nn.Conv2d.\n",
        "    padding: padding forwarded to nn.Conv2d.\n",
        "\n",
        "  Returns:\n",
        "    nn.Sequential holding the Conv2d at index 0 and the ReLU at index 1.\n",
        "  \"\"\"\n",
        "  conv = nn.Conv2d(in_channels, out_channels, kernel, padding=padding)\n",
        "  activation = nn.ReLU(inplace=True)\n",
        "  return nn.Sequential(conv, activation)\n",
        "\n",
        "\n",
        "class ResNetUNet(nn.Module):\n",
        "  \"\"\"U-Net style network that uses torchvision's ResNet-18 as its encoder.\n",
        "\n",
        "  Encoder features at five scales (1/2 to 1/32 of the input) each pass through\n",
        "  a 1x1 conv + ReLU and are concatenated with the upsampled decoder features.\n",
        "  A two-conv branch on the raw input supplies full-resolution features for the\n",
        "  last fusion step.\n",
        "\n",
        "  Args:\n",
        "    n_class: number of channels produced by the final 1x1 convolution.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, n_class):\n",
        "    super().__init__()\n",
        "\n",
        "    # Keeping base_model as an attribute registers the whole ResNet, including\n",
        "    # its avgpool/fc head that forward() never calls; the layerN attributes\n",
        "    # below reuse the same module objects rather than copies.\n",
        "    self.base_model = torchvision.models.resnet18(pretrained=True)\n",
        "    self.base_layers = list(self.base_model.children())\n",
        "\n",
        "    # Encoder stages, each paired with a 1x1 conv+ReLU applied to its skip connection.\n",
        "    self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)\n",
        "    self.layer0_1x1 = convrelu(64, 64, 1, 0)\n",
        "    self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)\n",
        "    self.layer1_1x1 = convrelu(64, 64, 1, 0)\n",
        "    self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)\n",
        "    self.layer2_1x1 = convrelu(128, 128, 1, 0)\n",
        "    self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)\n",
        "    self.layer3_1x1 = convrelu(256, 256, 1, 0)\n",
        "    self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)\n",
        "    self.layer4_1x1 = convrelu(512, 512, 1, 0)\n",
        "\n",
        "    self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)\n",
        "\n",
        "    # Decoder convs: input channels = skip channels + channels coming from below.\n",
        "    self.conv_up3 = convrelu(256 + 512, 512, 3, 1)\n",
        "    self.conv_up2 = convrelu(128 + 512, 256, 3, 1)\n",
        "    self.conv_up1 = convrelu(64 + 256, 256, 3, 1)\n",
        "    self.conv_up0 = convrelu(64 + 256, 128, 3, 1)\n",
        "\n",
        "    # Full-resolution branch on the raw input and the conv that fuses it back in.\n",
        "    self.conv_original_size0 = convrelu(3, 64, 3, 1)\n",
        "    self.conv_original_size1 = convrelu(64, 64, 3, 1)\n",
        "    self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)\n",
        "\n",
        "    self.conv_last = nn.Conv2d(64, n_class, 1)\n",
        "\n",
        "  def _upsample_to(self, x, reference):\n",
        "    \"\"\"Upsample x so that its height and width equal those of reference.\n",
        "\n",
        "    When doubling x already gives reference's spatial size, self.upsample is\n",
        "    used unchanged. Otherwise (input sides that are not multiples of 32 make\n",
        "    the strided encoder produce odd sizes) x is interpolated straight to\n",
        "    reference's size with the same mode and align_corners, so the following\n",
        "    torch.cat does not fail on a shape mismatch.\n",
        "    \"\"\"\n",
        "    target = tuple(reference.shape[2:])\n",
        "    # self.upsample doubles each spatial dimension (scale_factor=2).\n",
        "    if tuple(2 * side for side in x.shape[2:]) == target:\n",
        "      return self.upsample(x)\n",
        "    return nn.functional.interpolate(\n",
        "      x, size=target, mode=self.upsample.mode,\n",
        "      align_corners=self.upsample.align_corners)\n",
        "\n",
        "  def forward(self, input):\n",
        "    \"\"\"Encode the input, then decode back to the input resolution.\n",
        "\n",
        "    Args:\n",
        "      input: image batch of shape (N, 3, H, W). H and W do not have to be\n",
        "        multiples of 32; decoder features are resized to each skip\n",
        "        connection's spatial size before concatenation.\n",
        "\n",
        "    Returns:\n",
        "      Tensor of shape (N, n_class, H, W). No activation is applied to it.\n",
        "    \"\"\"\n",
        "    # Full-resolution features, fused in at the very end.\n",
        "    x_original = self.conv_original_size0(input)\n",
        "    x_original = self.conv_original_size1(x_original)\n",
        "\n",
        "    # Encoder pass.\n",
        "    layer0 = self.layer0(input)\n",
        "    layer1 = self.layer1(layer0)\n",
        "    layer2 = self.layer2(layer1)\n",
        "    layer3 = self.layer3(layer2)\n",
        "    layer4 = self.layer4(layer3)\n",
        "\n",
        "    # Decoder: upsample, concatenate the projected skip, convolve.\n",
        "    layer4 = self.layer4_1x1(layer4)\n",
        "    x = self._upsample_to(layer4, layer3)\n",
        "    layer3 = self.layer3_1x1(layer3)\n",
        "    x = torch.cat([x, layer3], dim=1)\n",
        "    x = self.conv_up3(x)\n",
        "\n",
        "    x = self._upsample_to(x, layer2)\n",
        "    layer2 = self.layer2_1x1(layer2)\n",
        "    x = torch.cat([x, layer2], dim=1)\n",
        "    x = self.conv_up2(x)\n",
        "\n",
        "    x = self._upsample_to(x, layer1)\n",
        "    layer1 = self.layer1_1x1(layer1)\n",
        "    x = torch.cat([x, layer1], dim=1)\n",
        "    x = self.conv_up1(x)\n",
        "\n",
        "    x = self._upsample_to(x, layer0)\n",
        "    layer0 = self.layer0_1x1(layer0)\n",
        "    x = torch.cat([x, layer0], dim=1)\n",
        "    x = self.conv_up0(x)\n",
        "\n",
        "    x = self._upsample_to(x, x_original)\n",
        "    x = torch.cat([x, x_original], dim=1)\n",
        "    x = self.conv_original_size2(x)\n",
        "\n",
        "    out = self.conv_last(x)\n",
        "\n",
        "    return out"
      ],
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gJ65Br1oDCOX",
        "colab_type": "text"
      },
      "source": [
        "## Instantiate the UNet model\n",
        "\n",
        "- Move the model to GPU if available\n",
        "- Show model summaries"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bY0Vk2VDCAiz",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "bd4cd640-2b9d-4d89-9a66-aa421d14c7cd"
      },
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import pytorch_unet\n",
        "\n",
        "# Run on the GPU when CUDA is usable, otherwise stay on the CPU.\n",
        "if torch.cuda.is_available():\n",
        "  device = torch.device('cuda')\n",
        "else:\n",
        "  device = torch.device('cpu')\n",
        "print('device', device)\n",
        "\n",
        "# Build the network with 6 output channels and move its weights to `device`.\n",
        "model = ResNetUNet(6).to(device)"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "device cuda\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RaZdFgOnGA_p",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "c0fdf0a7-8b93-49d9-81e7-8b8cb5ad30e2"
      },
      "source": [
        "# Bare expression: the notebook renders the module tree (repr) as the cell output.\n",
        "model"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "ResNetUNet(\n",
              "  (base_model): ResNet(\n",
              "    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
              "    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "    (relu): ReLU(inplace=True)\n",
              "    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
              "    (layer1): Sequential(\n",
              "      (0): BasicBlock(\n",
              "        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (relu): ReLU(inplace=True)\n",
              "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "      (1): BasicBlock(\n",
              "        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (relu): ReLU(inplace=True)\n",
              "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "    (layer2): Sequential(\n",
              "      (0): BasicBlock(\n",
              "        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
              "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (relu): ReLU(inplace=True)\n",
              "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (downsample): Sequential(\n",
              "          (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
              "          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        )\n",
              "      )\n",
              "      (1): BasicBlock(\n",
              "        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (relu): ReLU(inplace=True)\n",
              "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "    (layer3): Sequential(\n",
              "      (0): BasicBlock(\n",
              "        (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
              "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (relu): ReLU(inplace=True)\n",
              "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (downsample): Sequential(\n",
              "          (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
              "          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        )\n",
              "      )\n",
              "      (1): BasicBlock(\n",
              "        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (relu): ReLU(inplace=True)\n",
              "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "    (layer4): Sequential(\n",
              "      (0): BasicBlock(\n",
              "        (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
              "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (relu): ReLU(inplace=True)\n",
              "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (downsample): Sequential(\n",
              "          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
              "          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        )\n",
              "      )\n",
              "      (1): BasicBlock(\n",
              "        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (relu): ReLU(inplace=True)\n",
              "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "    (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
              "    (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
              "  )\n",
              "  (layer0): Sequential(\n",
              "    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
              "    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "    (2): ReLU(inplace=True)\n",
              "  )\n",
              "  (layer0_1x1): Sequential(\n",
              "    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
              "    (1): ReLU(inplace=True)\n",
              "  )\n",
              "  (layer1): Sequential(\n",
              "    (0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
              "    (1): Sequential(\n",
              "      (0): BasicBlock(\n",
              "        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (relu): ReLU(inplace=True)\n",
              "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "      (1): BasicBlock(\n",
              "        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "        (relu): ReLU(inplace=True)\n",
              "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "  )\n",
              "  (layer1_1x1): Sequential(\n",
              "    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
              "    (1): ReLU(inplace=True)\n",
              "  )\n",
              "  (layer2): Sequential(\n",
              "    (0): BasicBlock(\n",
              "      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (relu): ReLU(inplace=True)\n",
              "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (downsample): Sequential(\n",
              "        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
              "        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "    (1): BasicBlock(\n",
              "      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (relu): ReLU(inplace=True)\n",
              "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "    )\n",
              "  )\n",
              "  (layer2_1x1): Sequential(\n",
              "    (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n",
              "    (1): ReLU(inplace=True)\n",
              "  )\n",
              "  (layer3): Sequential(\n",
              "    (0): BasicBlock(\n",
              "      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (relu): ReLU(inplace=True)\n",
              "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (downsample): Sequential(\n",
              "        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
              "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "    (1): BasicBlock(\n",
              "      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (relu): ReLU(inplace=True)\n",
              "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "    )\n",
              "  )\n",
              "  (layer3_1x1): Sequential(\n",
              "    (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))\n",
              "    (1): ReLU(inplace=True)\n",
              "  )\n",
              "  (layer4): Sequential(\n",
              "    (0): BasicBlock(\n",
              "      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (relu): ReLU(inplace=True)\n",
              "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (downsample): Sequential(\n",
              "        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
              "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      )\n",
              "    )\n",
              "    (1): BasicBlock(\n",
              "      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (relu): ReLU(inplace=True)\n",
              "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
              "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "    )\n",
              "  )\n",
              "  (layer4_1x1): Sequential(\n",
              "    (0): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))\n",
              "    (1): ReLU(inplace=True)\n",
              "  )\n",
              "  (upsample): Upsample(scale_factor=2.0, mode=bilinear)\n",
              "  (conv_up3): Sequential(\n",
              "    (0): Conv2d(768, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
              "    (1): ReLU(inplace=True)\n",
              "  )\n",
              "  (conv_up2): Sequential(\n",
              "    (0): Conv2d(640, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
              "    (1): ReLU(inplace=True)\n",
              "  )\n",
              "  (conv_up1): Sequential(\n",
              "    (0): Conv2d(320, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
              "    (1): ReLU(inplace=True)\n",
              "  )\n",
              "  (conv_up0): Sequential(\n",
              "    (0): Conv2d(320, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
              "    (1): ReLU(inplace=True)\n",
              "  )\n",
              "  (conv_original_size0): Sequential(\n",
              "    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
              "    (1): ReLU(inplace=True)\n",
              "  )\n",
              "  (conv_original_size1): Sequential(\n",
              "    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
              "    (1): ReLU(inplace=True)\n",
              "  )\n",
              "  (conv_original_size2): Sequential(\n",
              "    (0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
              "    (1): ReLU(inplace=True)\n",
              "  )\n",
              "  (conv_last): Conv2d(64, 6, kernel_size=(1, 1), stride=(1, 1))\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 10
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MoVYhHpbCSdY",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "e6f9b882-fed1-4343-fbe7-d655732c0b2d"
      },
      "source": [
        "from torchsummary import summary\n",
        "# Per-layer output shapes and parameter counts for one 3x224x224 input.\n",
        "# NOTE(review): in the recorded output each encoder layer is listed twice (e.g. Conv2d-5\n",
        "# and Conv2d-6), presumably because the backbone modules are reachable both via base_model\n",
        "# and via layer0..layer4 and so get hooked twice; confirm before trusting the totals.\n",
        "summary(model, input_size=(3, 224, 224))"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "----------------------------------------------------------------\n",
            "        Layer (type)               Output Shape         Param #\n",
            "================================================================\n",
            "            Conv2d-1         [-1, 64, 224, 224]           1,792\n",
            "              ReLU-2         [-1, 64, 224, 224]               0\n",
            "            Conv2d-3         [-1, 64, 224, 224]          36,928\n",
            "              ReLU-4         [-1, 64, 224, 224]               0\n",
            "            Conv2d-5         [-1, 64, 112, 112]           9,408\n",
            "            Conv2d-6         [-1, 64, 112, 112]           9,408\n",
            "       BatchNorm2d-7         [-1, 64, 112, 112]             128\n",
            "       BatchNorm2d-8         [-1, 64, 112, 112]             128\n",
            "              ReLU-9         [-1, 64, 112, 112]               0\n",
            "             ReLU-10         [-1, 64, 112, 112]               0\n",
            "        MaxPool2d-11           [-1, 64, 56, 56]               0\n",
            "        MaxPool2d-12           [-1, 64, 56, 56]               0\n",
            "           Conv2d-13           [-1, 64, 56, 56]          36,864\n",
            "           Conv2d-14           [-1, 64, 56, 56]          36,864\n",
            "      BatchNorm2d-15           [-1, 64, 56, 56]             128\n",
            "      BatchNorm2d-16           [-1, 64, 56, 56]             128\n",
            "             ReLU-17           [-1, 64, 56, 56]               0\n",
            "             ReLU-18           [-1, 64, 56, 56]               0\n",
            "           Conv2d-19           [-1, 64, 56, 56]          36,864\n",
            "           Conv2d-20           [-1, 64, 56, 56]          36,864\n",
            "      BatchNorm2d-21           [-1, 64, 56, 56]             128\n",
            "      BatchNorm2d-22           [-1, 64, 56, 56]             128\n",
            "             ReLU-23           [-1, 64, 56, 56]               0\n",
            "             ReLU-24           [-1, 64, 56, 56]               0\n",
            "       BasicBlock-25           [-1, 64, 56, 56]               0\n",
            "       BasicBlock-26           [-1, 64, 56, 56]               0\n",
            "           Conv2d-27           [-1, 64, 56, 56]          36,864\n",
            "           Conv2d-28           [-1, 64, 56, 56]          36,864\n",
            "      BatchNorm2d-29           [-1, 64, 56, 56]             128\n",
            "      BatchNorm2d-30           [-1, 64, 56, 56]             128\n",
            "             ReLU-31           [-1, 64, 56, 56]               0\n",
            "             ReLU-32           [-1, 64, 56, 56]               0\n",
            "           Conv2d-33           [-1, 64, 56, 56]          36,864\n",
            "           Conv2d-34           [-1, 64, 56, 56]          36,864\n",
            "      BatchNorm2d-35           [-1, 64, 56, 56]             128\n",
            "      BatchNorm2d-36           [-1, 64, 56, 56]             128\n",
            "             ReLU-37           [-1, 64, 56, 56]               0\n",
            "             ReLU-38           [-1, 64, 56, 56]               0\n",
            "       BasicBlock-39           [-1, 64, 56, 56]               0\n",
            "       BasicBlock-40           [-1, 64, 56, 56]               0\n",
            "           Conv2d-41          [-1, 128, 28, 28]          73,728\n",
            "           Conv2d-42          [-1, 128, 28, 28]          73,728\n",
            "      BatchNorm2d-43          [-1, 128, 28, 28]             256\n",
            "      BatchNorm2d-44          [-1, 128, 28, 28]             256\n",
            "             ReLU-45          [-1, 128, 28, 28]               0\n",
            "             ReLU-46          [-1, 128, 28, 28]               0\n",
            "           Conv2d-47          [-1, 128, 28, 28]         147,456\n",
            "           Conv2d-48          [-1, 128, 28, 28]         147,456\n",
            "      BatchNorm2d-49          [-1, 128, 28, 28]             256\n",
            "      BatchNorm2d-50          [-1, 128, 28, 28]             256\n",
            "           Conv2d-51          [-1, 128, 28, 28]           8,192\n",
            "           Conv2d-52          [-1, 128, 28, 28]           8,192\n",
            "      BatchNorm2d-53          [-1, 128, 28, 28]             256\n",
            "      BatchNorm2d-54          [-1, 128, 28, 28]             256\n",
            "             ReLU-55          [-1, 128, 28, 28]               0\n",
            "             ReLU-56          [-1, 128, 28, 28]               0\n",
            "       BasicBlock-57          [-1, 128, 28, 28]               0\n",
            "       BasicBlock-58          [-1, 128, 28, 28]               0\n",
            "           Conv2d-59          [-1, 128, 28, 28]         147,456\n",
            "           Conv2d-60          [-1, 128, 28, 28]         147,456\n",
            "      BatchNorm2d-61          [-1, 128, 28, 28]             256\n",
            "      BatchNorm2d-62          [-1, 128, 28, 28]             256\n",
            "             ReLU-63          [-1, 128, 28, 28]               0\n",
            "             ReLU-64          [-1, 128, 28, 28]               0\n",
            "           Conv2d-65          [-1, 128, 28, 28]         147,456\n",
            "           Conv2d-66          [-1, 128, 28, 28]         147,456\n",
            "      BatchNorm2d-67          [-1, 128, 28, 28]             256\n",
            "      BatchNorm2d-68          [-1, 128, 28, 28]             256\n",
            "             ReLU-69          [-1, 128, 28, 28]               0\n",
            "             ReLU-70          [-1, 128, 28, 28]               0\n",
            "       BasicBlock-71          [-1, 128, 28, 28]               0\n",
            "       BasicBlock-72          [-1, 128, 28, 28]               0\n",
            "           Conv2d-73          [-1, 256, 14, 14]         294,912\n",
            "           Conv2d-74          [-1, 256, 14, 14]         294,912\n",
            "      BatchNorm2d-75          [-1, 256, 14, 14]             512\n",
            "      BatchNorm2d-76          [-1, 256, 14, 14]             512\n",
            "             ReLU-77          [-1, 256, 14, 14]               0\n",
            "             ReLU-78          [-1, 256, 14, 14]               0\n",
            "           Conv2d-79          [-1, 256, 14, 14]         589,824\n",
            "           Conv2d-80          [-1, 256, 14, 14]         589,824\n",
            "      BatchNorm2d-81          [-1, 256, 14, 14]             512\n",
            "      BatchNorm2d-82          [-1, 256, 14, 14]             512\n",
            "           Conv2d-83          [-1, 256, 14, 14]          32,768\n",
            "           Conv2d-84          [-1, 256, 14, 14]          32,768\n",
            "      BatchNorm2d-85          [-1, 256, 14, 14]             512\n",
            "      BatchNorm2d-86          [-1, 256, 14, 14]             512\n",
            "             ReLU-87          [-1, 256, 14, 14]               0\n",
            "             ReLU-88          [-1, 256, 14, 14]               0\n",
            "       BasicBlock-89          [-1, 256, 14, 14]               0\n",
            "       BasicBlock-90          [-1, 256, 14, 14]               0\n",
            "           Conv2d-91          [-1, 256, 14, 14]         589,824\n",
            "           Conv2d-92          [-1, 256, 14, 14]         589,824\n",
            "      BatchNorm2d-93          [-1, 256, 14, 14]             512\n",
            "      BatchNorm2d-94          [-1, 256, 14, 14]             512\n",
            "             ReLU-95          [-1, 256, 14, 14]               0\n",
            "             ReLU-96          [-1, 256, 14, 14]               0\n",
            "           Conv2d-97          [-1, 256, 14, 14]         589,824\n",
            "           Conv2d-98          [-1, 256, 14, 14]         589,824\n",
            "      BatchNorm2d-99          [-1, 256, 14, 14]             512\n",
            "     BatchNorm2d-100          [-1, 256, 14, 14]             512\n",
            "            ReLU-101          [-1, 256, 14, 14]               0\n",
            "            ReLU-102          [-1, 256, 14, 14]               0\n",
            "      BasicBlock-103          [-1, 256, 14, 14]               0\n",
            "      BasicBlock-104          [-1, 256, 14, 14]               0\n",
            "          Conv2d-105            [-1, 512, 7, 7]       1,179,648\n",
            "          Conv2d-106            [-1, 512, 7, 7]       1,179,648\n",
            "     BatchNorm2d-107            [-1, 512, 7, 7]           1,024\n",
            "     BatchNorm2d-108            [-1, 512, 7, 7]           1,024\n",
            "            ReLU-109            [-1, 512, 7, 7]               0\n",
            "            ReLU-110            [-1, 512, 7, 7]               0\n",
            "          Conv2d-111            [-1, 512, 7, 7]       2,359,296\n",
            "          Conv2d-112            [-1, 512, 7, 7]       2,359,296\n",
            "     BatchNorm2d-113            [-1, 512, 7, 7]           1,024\n",
            "     BatchNorm2d-114            [-1, 512, 7, 7]           1,024\n",
            "          Conv2d-115            [-1, 512, 7, 7]         131,072\n",
            "          Conv2d-116            [-1, 512, 7, 7]         131,072\n",
            "     BatchNorm2d-117            [-1, 512, 7, 7]           1,024\n",
            "     BatchNorm2d-118            [-1, 512, 7, 7]           1,024\n",
            "            ReLU-119            [-1, 512, 7, 7]               0\n",
            "            ReLU-120            [-1, 512, 7, 7]               0\n",
            "      BasicBlock-121            [-1, 512, 7, 7]               0\n",
            "      BasicBlock-122            [-1, 512, 7, 7]               0\n",
            "          Conv2d-123            [-1, 512, 7, 7]       2,359,296\n",
            "          Conv2d-124            [-1, 512, 7, 7]       2,359,296\n",
            "     BatchNorm2d-125            [-1, 512, 7, 7]           1,024\n",
            "     BatchNorm2d-126            [-1, 512, 7, 7]           1,024\n",
            "            ReLU-127            [-1, 512, 7, 7]               0\n",
            "            ReLU-128            [-1, 512, 7, 7]               0\n",
            "          Conv2d-129            [-1, 512, 7, 7]       2,359,296\n",
            "          Conv2d-130            [-1, 512, 7, 7]       2,359,296\n",
            "     BatchNorm2d-131            [-1, 512, 7, 7]           1,024\n",
            "     BatchNorm2d-132            [-1, 512, 7, 7]           1,024\n",
            "            ReLU-133            [-1, 512, 7, 7]               0\n",
            "            ReLU-134            [-1, 512, 7, 7]               0\n",
            "      BasicBlock-135            [-1, 512, 7, 7]               0\n",
            "      BasicBlock-136            [-1, 512, 7, 7]               0\n",
            "          Conv2d-137            [-1, 512, 7, 7]         262,656\n",
            "            ReLU-138            [-1, 512, 7, 7]               0\n",
            "        Upsample-139          [-1, 512, 14, 14]               0\n",
            "          Conv2d-140          [-1, 256, 14, 14]          65,792\n",
            "            ReLU-141          [-1, 256, 14, 14]               0\n",
            "          Conv2d-142          [-1, 512, 14, 14]       3,539,456\n",
            "            ReLU-143          [-1, 512, 14, 14]               0\n",
            "        Upsample-144          [-1, 512, 28, 28]               0\n",
            "          Conv2d-145          [-1, 128, 28, 28]          16,512\n",
            "            ReLU-146          [-1, 128, 28, 28]               0\n",
            "          Conv2d-147          [-1, 256, 28, 28]       1,474,816\n",
            "            ReLU-148          [-1, 256, 28, 28]               0\n",
            "        Upsample-149          [-1, 256, 56, 56]               0\n",
            "          Conv2d-150           [-1, 64, 56, 56]           4,160\n",
            "            ReLU-151           [-1, 64, 56, 56]               0\n",
            "          Conv2d-152          [-1, 256, 56, 56]         737,536\n",
            "            ReLU-153          [-1, 256, 56, 56]               0\n",
            "        Upsample-154        [-1, 256, 112, 112]               0\n",
            "          Conv2d-155         [-1, 64, 112, 112]           4,160\n",
            "            ReLU-156         [-1, 64, 112, 112]               0\n",
            "          Conv2d-157        [-1, 128, 112, 112]         368,768\n",
            "            ReLU-158        [-1, 128, 112, 112]               0\n",
            "        Upsample-159        [-1, 128, 224, 224]               0\n",
            "          Conv2d-160         [-1, 64, 224, 224]         110,656\n",
            "            ReLU-161         [-1, 64, 224, 224]               0\n",
            "          Conv2d-162          [-1, 6, 224, 224]             390\n",
            "================================================================\n",
            "Total params: 28,976,646\n",
            "Trainable params: 28,976,646\n",
            "Non-trainable params: 0\n",
            "----------------------------------------------------------------\n",
            "Input size (MB): 0.57\n",
            "Forward/backward pass size (MB): 417.65\n",
            "Params size (MB): 110.54\n",
            "Estimated Total Size (MB): 528.76\n",
            "----------------------------------------------------------------\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H7rAEQCUEI2v",
        "colab_type": "text"
      },
      "source": [
        "# Define the main training loop"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tjt9JeTuDY6D",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from collections import defaultdict\n",
        "import torch.nn.functional as F\n",
        "from loss import dice_loss\n",
        "\n",
        "checkpoint_path = \"checkpoint.pth\"\n",
        "\n",
        "def calc_loss(pred, target, metrics, bce_weight=0.5):\n",
        "    \"\"\"Compute the weighted BCE + Dice loss and accumulate running metrics.\n",
        "\n",
        "    Args:\n",
        "        pred: Raw model outputs (logits); the sigmoid is applied here before\n",
        "            the Dice term is computed.\n",
        "        target: Ground-truth tensor compared against ``pred`` by both loss terms.\n",
        "        metrics: Mutable mapping whose 'bce', 'dice' and 'loss' entries are each\n",
        "            incremented by the batch loss multiplied by ``target.size(0)``.\n",
        "        bce_weight: Weight of the BCE term; the Dice term gets ``1 - bce_weight``.\n",
        "\n",
        "    Returns:\n",
        "        The weighted loss tensor, still attached to the autograd graph.\n",
        "    \"\"\"\n",
        "    bce = F.binary_cross_entropy_with_logits(pred, target)\n",
        "\n",
        "    pred = torch.sigmoid(pred)\n",
        "    dice = dice_loss(pred, target)\n",
        "\n",
        "    loss = bce * bce_weight + dice * (1 - bce_weight)\n",
        "\n",
        "    # Weight each batch by its first dimension so the caller can divide the\n",
        "    # accumulated sums by the number of samples seen.\n",
        "    batch_size = target.size(0)\n",
        "    # .item() extracts the Python scalar without autograd history; it replaces\n",
        "    # the legacy `.data.cpu().numpy()` chain.\n",
        "    metrics['bce'] += bce.item() * batch_size\n",
        "    metrics['dice'] += dice.item() * batch_size\n",
        "    metrics['loss'] += loss.item() * batch_size\n",
        "\n",
        "    return loss\n",
        "\n",
        "def print_metrics(metrics, epoch_samples, phase):\n",
        "    \"\"\"Print one summary line of per-sample metric averages for a phase.\n",
        "\n",
        "    Each value in ``metrics`` is divided by ``epoch_samples`` and the results\n",
        "    are printed as ``<phase>: <name>: <value>, ...`` in the mapping's\n",
        "    iteration order.\n",
        "    \"\"\"\n",
        "    formatted = [\n",
        "        \"{}: {:4f}\".format(name, total / epoch_samples)\n",
        "        for name, total in metrics.items()\n",
        "    ]\n",
        "    summary = \", \".join(formatted)\n",
        "    print(\"{}: {}\".format(phase, summary))\n",
        "\n",
        "def train_model(model, optimizer, scheduler, num_epochs=25):\n",
        "    \"\"\"Train ``model`` and return it with the best validation weights loaded.\n",
        "\n",
        "    Every epoch runs a 'train' pass and then a 'val' pass over the notebook-level\n",
        "    ``dataloaders`` mapping, moving each batch to the notebook-level ``device``.\n",
        "    Whenever the mean validation loss improves, the weights are written to\n",
        "    ``checkpoint_path``.\n",
        "\n",
        "    Args:\n",
        "        model: Network to optimise; put in train/eval mode according to the phase.\n",
        "        optimizer: Optimizer stepped once per training batch.\n",
        "        scheduler: LR scheduler stepped once per epoch, after the training phase.\n",
        "        num_epochs: Number of epochs to run.\n",
        "\n",
        "    Returns:\n",
        "        ``model``; the best checkpoint of this run is loaded into it when one\n",
        "        was saved, otherwise its current weights are left untouched.\n",
        "    \"\"\"\n",
        "    # Local import so the function does not rely on a later cell's `import time`.\n",
        "    import time\n",
        "\n",
        "    best_loss = float('inf')\n",
        "    # Set once this run writes a checkpoint, so a stale file left behind by an\n",
        "    # earlier run is never loaded by mistake.\n",
        "    checkpoint_saved = False\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n",
        "        print('-' * 10)\n",
        "\n",
        "        since = time.time()\n",
        "\n",
        "        # Each epoch has a training and validation phase\n",
        "        for phase in ['train', 'val']:\n",
        "            if phase == 'train':\n",
        "                model.train()  # Set model to training mode\n",
        "            else:\n",
        "                model.eval()   # Set model to evaluate mode\n",
        "\n",
        "            # Sums of batch losses weighted by batch size (filled by calc_loss)\n",
        "            metrics = defaultdict(float)\n",
        "            epoch_samples = 0\n",
        "\n",
        "            for inputs, labels in dataloaders[phase]:\n",
        "                inputs = inputs.to(device)\n",
        "                labels = labels.to(device)\n",
        "\n",
        "                # zero the parameter gradients\n",
        "                optimizer.zero_grad()\n",
        "\n",
        "                # forward\n",
        "                # track history if only in train\n",
        "                with torch.set_grad_enabled(phase == 'train'):\n",
        "                    outputs = model(inputs)\n",
        "                    loss = calc_loss(outputs, labels, metrics)\n",
        "\n",
        "                    # backward + optimize only if in training phase\n",
        "                    if phase == 'train':\n",
        "                        loss.backward()\n",
        "                        optimizer.step()\n",
        "\n",
        "                # statistics\n",
        "                epoch_samples += inputs.size(0)\n",
        "\n",
        "            print_metrics(metrics, epoch_samples, phase)\n",
        "            epoch_loss = metrics['loss'] / epoch_samples\n",
        "\n",
        "            if phase == 'train':\n",
        "                # Step the LR schedule once per epoch, then report the rate(s)\n",
        "                # now held by the optimizer.\n",
        "                scheduler.step()\n",
        "                for param_group in optimizer.param_groups:\n",
        "                    print(\"LR\", param_group['lr'])\n",
        "\n",
        "            # save the model weights\n",
        "            if phase == 'val' and epoch_loss < best_loss:\n",
        "                print(f\"saving best model to {checkpoint_path}\")\n",
        "                best_loss = epoch_loss\n",
        "                torch.save(model.state_dict(), checkpoint_path)\n",
        "                checkpoint_saved = True\n",
        "\n",
        "        time_elapsed = time.time() - since\n",
        "        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n",
        "\n",
        "    print('Best val loss: {:4f}'.format(best_loss))\n",
        "\n",
        "    # load best model weights (only when this run actually saved some)\n",
        "    if checkpoint_saved:\n",
        "        model.load_state_dict(torch.load(checkpoint_path))\n",
        "    else:\n",
        "        print('no checkpoint was saved in this run; keeping the current weights')\n",
        "    return model"
      ],
      "execution_count": 12,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "adcdAu9ZEOLG",
        "colab_type": "text"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RfxgL303EMiy",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "9e0a2ff2-2ebc-4484-d845-f229f5764a53"
      },
      "source": [
        "import torch\n",
        "import torch.optim as optim\n",
        "from torch.optim import lr_scheduler\n",
        "import time\n",
        "\n",
        "# Number of output channels the network predicts.\n",
        "num_class = 6\n",
        "model = ResNetUNet(num_class)\n",
        "model = model.to(device)\n",
        "\n",
        "# Freeze the backbone: parameters under model.base_layers get no gradients.\n",
        "for layer in model.base_layers:\n",
        "    for param in layer.parameters():\n",
        "        param.requires_grad = False\n",
        "\n",
        "# Hand only the still-trainable parameters to the optimizer.\n",
        "trainable_params = [p for p in model.parameters() if p.requires_grad]\n",
        "optimizer_ft = optim.Adam(trainable_params, lr=1e-4)\n",
        "\n",
        "# Multiply the learning rate by 0.1 after every 8 scheduler steps.\n",
        "exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=8, gamma=0.1)\n",
        "\n",
        "model = train_model(\n",
        "    model,\n",
        "    optimizer_ft,\n",
        "    exp_lr_scheduler,\n",
        "    num_epochs=10,\n",
        ")"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch 0/9\n",
            "----------\n",
            "train: bce: 0.108924, dice: 0.968998, loss: 0.538961\n",
            "LR 0.0001\n",
            "val: bce: 0.021421, dice: 0.818291, loss: 0.419856\n",
            "saving best model to checkpoint.pth\n",
            "0m 21s\n",
            "Epoch 1/9\n",
            "----------\n",
            "train: bce: 0.015206, dice: 0.599171, loss: 0.307188\n",
            "LR 0.0001\n",
            "val: bce: 0.007349, dice: 0.340564, loss: 0.173956\n",
            "saving best model to checkpoint.pth\n",
            "0m 21s\n",
            "Epoch 2/9\n",
            "----------\n",
            "train: bce: 0.004170, dice: 0.183031, loss: 0.093601\n",
            "LR 0.0001\n",
            "val: bce: 0.002601, dice: 0.100720, loss: 0.051660\n",
            "saving best model to checkpoint.pth\n",
            "0m 21s\n",
            "Epoch 3/9\n",
            "----------\n",
            "train: bce: 0.002347, dice: 0.078547, loss: 0.040447\n",
            "LR 0.0001\n",
            "val: bce: 0.002113, dice: 0.069696, loss: 0.035905\n",
            "saving best model to checkpoint.pth\n",
            "0m 21s\n",
            "Epoch 4/9\n",
            "----------\n",
            "train: bce: 0.001875, dice: 0.057423, loss: 0.029649\n",
            "LR 0.0001\n",
            "val: bce: 0.001755, dice: 0.057333, loss: 0.029544\n",
            "saving best model to checkpoint.pth\n",
            "0m 21s\n",
            "Epoch 5/9\n",
            "----------\n",
            "train: bce: 0.001587, dice: 0.047364, loss: 0.024475\n",
            "LR 0.0001\n",
            "val: bce: 0.001657, dice: 0.051201, loss: 0.026429\n",
            "saving best model to checkpoint.pth\n",
            "0m 21s\n",
            "Epoch 6/9\n",
            "----------\n",
            "train: bce: 0.001377, dice: 0.040896, loss: 0.021137\n",
            "LR 0.0001\n",
            "val: bce: 0.001525, dice: 0.046798, loss: 0.024161\n",
            "saving best model to checkpoint.pth\n",
            "0m 21s\n",
            "Epoch 7/9\n",
            "----------\n",
            "train: bce: 0.001265, dice: 0.036721, loss: 0.018993\n",
            "LR 1e-05\n",
            "val: bce: 0.001378, dice: 0.042819, loss: 0.022099\n",
            "saving best model to checkpoint.pth\n",
            "0m 21s\n",
            "Epoch 8/9\n",
            "----------\n",
            "train: bce: 0.001181, dice: 0.033613, loss: 0.017397\n",
            "LR 1e-05\n",
            "val: bce: 0.001392, dice: 0.042263, loss: 0.021827\n",
            "saving best model to checkpoint.pth\n",
            "0m 21s\n",
            "Epoch 9/9\n",
            "----------\n",
            "train: bce: 0.001152, dice: 0.033016, loss: 0.017084\n",
            "LR 1e-05\n",
            "val: bce: 0.001415, dice: 0.042124, loss: 0.021769\n",
            "saving best model to checkpoint.pth\n",
            "0m 21s\n",
            "Best val loss: 0.021769\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lcRgjfk5D-kP",
        "colab_type": "text"
      },
      "source": [
        "## Predict new images using the trained model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xXRtpxHRET-v",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 70
        },
        "outputId": "285f06ae-2502-4402-a79e-b0564933c425"
      },
      "source": [
        "import math\n",
        "\n",
        "model.eval()   # Set model to the evaluation mode\n",
        "\n",
        "# Create a new simulation dataset for testing\n",
        "test_dataset = SimDataset(3, transform = trans)\n",
        "test_loader = DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=0)\n",
        "\n",
        "# Get the first batch\n",
        "inputs, labels = next(iter(test_loader))\n",
        "inputs = inputs.to(device)\n",
        "labels = labels.to(device)\n",
        "print('inputs.shape', inputs.shape)\n",
        "print('labels.shape', labels.shape)\n",
        "\n",
        "# Predict. Inference needs no gradients, so run the forward pass under\n",
        "# no_grad instead of building an autograd graph and stripping it with `.data`.\n",
        "with torch.no_grad():\n",
        "    pred = model(inputs)\n",
        "    # The model outputs logits (the training loss applies the sigmoid itself),\n",
        "    # so convert them to probabilities here.\n",
        "    pred = torch.sigmoid(pred)\n",
        "pred = pred.cpu().numpy()\n",
        "print('pred.shape', pred.shape)\n",
        "\n",
        "# Change channel-order and make 3 channels for matplot\n",
        "input_images_rgb = [reverse_transform(x) for x in inputs.cpu()]\n",
        "\n",
        "# Map each channel (i.e. class) to each color\n",
        "target_masks_rgb = [helper.masks_to_colorimg(x) for x in labels.cpu().numpy()]\n",
        "pred_rgb = [helper.masks_to_colorimg(x) for x in pred]"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "inputs.shape torch.Size([3, 3, 192, 192])\n",
            "labels.shape torch.Size([3, 6, 192, 192])\n",
            "pred.shape (3, 6, 192, 192)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XPQyJc4YD39T",
        "colab_type": "text"
      },
      "source": [
        "## Left: Input image, Middle: Correct mask (Ground-truth), Right: Predicted mask"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "z6dkJZLBCv4t",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 704
        },
        "outputId": "628f669d-0b36-4cb6-d2f4-2c971dcf57e3"
      },
      "source": [
        "# Image lists in the order named by the heading above:\n",
        "# inputs, ground-truth masks, predicted masks.\n",
        "panel_columns = [input_images_rgb, target_masks_rgb, pred_rgb]\n",
        "helper.plot_side_by_side(panel_columns)"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAsEAAAKvCAYAAACLTxJeAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdfaykd3kf/O+1u0n+ACpIOLGMzda8GKIkDZuycis38JAALdAAASSKVaUkQV2jgtS8PGoJPGrSPgqqUgh6qughWYQFlhIHGkNCIqeJY/GEEkzLOnEcQ3BiO0bYOPYGp4G8iNZ7ruePnWOP1+dlzpmZMy/35yONzsxv5p655ngv/77nN/fcd3V3AABgSI4sugAAADhsQjAAAIMjBAMAMDhCMAAAgyMEAwAwOEIwAACDM7cQXFUvq6o7qurOqnrbvF4HAAD2q+ZxnOCqOprkj5O8NMm9ST6T5Kru/tzMXwwAAPZpXivBVyS5s7vv7u7/leSXkrx6Tq8FAAD7cmxOz3tJki+O3b43yT/Y6cFV5bR1DNmfd/fGoovYj6c+9al92WWXLboMWIhbbrllpXpWvzJku/XrvELwnqrqVJJTi3p9WCJfWHQBkxjv2ePHj+fMmTMLrggWo6qWvmf1K5y3W7/Oa3eI+5I8fez2paOxR3T36e4+2d0n51QDMEPjPbuxsTKLYDBI+hX2Nq8Q/Jkkl1fVM6rq65O8IcnH5vRaAACwL3PZHaK7H66qtyb5zSRHk1zT3Z+dx2sBAMB+zW2f4O6+IckN83p+AAA4KGeMAwBgcIRgAAAGRwgGAGBwhGAAAAZHCAYAYHCEYAAABkcIBgBgcIRgAAAGRwgGAGBwhGAAAAZHCAYAYHCEYAAABkcIBgBgcIRgAAAGRwgGAGBwhGAAAAbnwCG4qp5eVR+vqs9V1Wer6l+Pxn+yqu6rqltHl1fMrlwAAJjesSm2fTjJj3X371XVk5LcUlU3ju57T3e/a/ryAABg9g4cgrv7/iT3j65/tar+KMklsyoMAADmZSb7BFfVZUm+M8l/Hw29tapuq6prquops3gNAACYlalDcFU9Mcn1SX64u7+S5L1JnpXkRM6vFL97h+1OVdWZqjozbQ3A/I337NmzZxddDrAL/Qp7myoEV9XX5XwA/oXu/kiSdPcD3X2uuzeTvC/JFdtt292nu/tkd5+cpgbgcIz37MbGxqLLYQ/dne5edBksiH5dLfp1MaY5OkQleX+SP+runxkbv3jsYa9JcvvBy2NZbHbn3ObmossAJtDdufaTV+baT1656FKAPejXxZnm6BD/KMn3J/nDqrp1NPb2JFdV1YkkneSeJFdPVSEAAMzYNEeH+GSS2uauGw5eDgAAzJ8zxgEAMDhCMAAAgyMEAwAwOEIwAACDIwQDADA4QjAAAIMjBAMAMDhCMAAAgzPNGeMAOCTdfSjbVW13DiRgP/bdr52k9redXp2eEDxQmweYUKtqX9t1d44e8WEDTKu7c+0nrzzQtvvd7o0vuPlArwOct+9+7eSbbjqaL7/43KPbdfLUm4497nFJHjlX7z/99/9t6lqHTkIZqO7e1+Ug2wEAe6uuR0PvdgE4ycZNx7KxzTgH57c5UPtdod0cBVsru3D4qmpfK7TjK1FWduFwHaRfP/Pj35NkbHX3hds88N/PoDgeQwgGWCOPfArTF9yegH0MYbH06+ESgtnV+D7A+9kn2KoxHL7uzme+63vS6Tz1JceSziMrTJO44nc/PsfqgAtt9eyW8et70a/Tk1IAAA7ZhQGYwycEAwAwOFPvDlFV9yT5apJzSR7u7pNV9Y1JPpTksiT3JHl9d//FtK8FAACzMKuV4O/u7hPdfXJ0+21Jburuy5PcNLoNAABLYV67Q7w6yQdH1z+Y5Pvm9DoAALBvswjBneS3quqWqjo1Gruou+8fXf+zJBfN4HUAAGAmZnGItO/q7v
uq6puT3FhVnx+/s7u7qh53XK1RYD514TiwnMZ79vjx4wuuBtiNfoW9Tb0S3N33jX4+mOSjSa5I8kBVXZwko58PbrPd6e4+ObYfMbDExnt2Y2Nj0eUAu9CvsLepQnBVPaGqnrR1Pck/TnJ7ko8leePoYW9M8qvTvA4A+zT5iacABmna3SEuSvLR0an7jiX5xe7+r1X1mSQfrqo3JflCktdP+ToATKhS2bjJCUEBdjPV/yW7++4kz9tm/MtJXjzNcwMAwLw4YxwAAIMjBAMAMDh2GmNX3b5dAwCsHyGYXR094sMCWBVVlSt+9+OLLgOYgH5dPAkHAIDBEYIBABgcIRgAgMERggEAGBwhGACAwRGCAQAYHCEYAIDBEYIBABgcIRgAgMERggEAGBwhGACAwRGCAQAYnGMH3bCqnpvkQ2NDz0zy75I8Ocm/THJ2NP727r7hwBUCAMCMHTgEd/cdSU4kSVUdTXJfko8m+cEk7+nud82kQgAAmLFZ7Q7x4iR3dfcXZvR8AAAwN7MKwW9Ict3Y7bdW1W1VdU1VPWVGrwEAADMxdQiuqq9P8qok/2U09N4kz8r5XSXuT/LuHbY7VVVnqurMtDUA8zfes2fPnt17A2Bh9CvsbRYrwS9P8nvd/UCSdPcD3X2uuzeTvC/JFdtt1N2nu/tkd5+cQQ3AnI337MbGxqLLAXahX2FvswjBV2VsV4iqunjsvtckuX0GrwEAADNz4KNDJElVPSHJS5NcPTb801V1IkknueeC+wAAYOGmCsHd/ddJvumCse+fqiIAAJgzZ4wDAGBwhGAAAAZHCAYAYHCEYAAABkcIBgBgcIRgAAAGRwgGAGBwhGAAAAZHCAYAYHCEYAAABkcIBgBgcIRggH3q7kWXAMCUhOAlcG5zc9ElABPq7hw5/UpBGFZEd+tXtiUEL4GqEoRhxQjCsPy2/mjVr2xHCF4SgjCsHhMrrA79yoWE4CUiCMPqMbHC6tCvjJsoBFfVNVX1YFXdPjb2jVV1Y1X9yejnU0bjVVX/uarurKrbqurvz6v4dSQIw+oxscLq0K9smXQl+ANJXnbB2NuS3NTdlye5aXQ7SV6e5PLR5VSS905f5rAIwrB6TKywOvQryYQhuLs/keShC4ZfneSDo+sfTPJ9Y+PX9nmfTvLkqrp4FsUOiSAMq8fECqtDvzLNPsEXdff9o+t/luSi0fVLknxx7HH3jsbYJ0EYVo+JFVaHfh22mXwxrs//C9rXv6KqOlVVZ6rqzCxqWFeCMMtivGfPnj276HKWmomVRdOvk9OvwzVNCH5gazeH0c8HR+P3JXn62OMuHY09Rnef7u6T3X1yihoGQRBmGYz37MbGxqLLWXomVhZJv+6Pfh2maULwx5K8cXT9jUl+dWz8X4yOEvEPk/zl2G4THJAgDKvHxAqrQ78Oz6SHSLsuyc1JnltV91bVm5L8xyQvrao/SfKS0e0kuSHJ3UnuTPK+JP9q5lUPlCAMq8fECqtDvw7LsUke1N1X7XDXi7d5bCd5yzRFsbOtIHz0iPOcwKo4cvqV2Tz1a6mqRZcC7EG/DocktYKsCMPqscIEq0O/DoMQvKIEYVg9JlZYHfp1/QnBK0wQhtVjYoUl0JNd9Ot6m2ifYOZr2gazjzCsliOnX5m++tcXXQYMxwXT7KWfunTPx2y58qffmU/d+fYksZ/wmhGCl4AAC6ujqtJX//rUf7yOb29ihfmoqmye+rVc+ex3ptO578p7d3zspZ96+o73XfnsdyZJPnXn2/XrGhGCl8TWau6sfgLz092PTIqzcPNd75jZcwHbq9SuQXcSVz77nYLwGpGWlsRWcJ3VTwBg9mb5BzCLJTEtia0vuM3qJwAAOxOCl4SVYACYrVnvusR6kZiWhJVgAIDDIwQvCSvBAACHR2JaElaCAQAOjxC8JKwEAwAcHolpSVgJBgA4PELwkrASDABweCSmJWElGADg8OwZgqvqmqp6sKpuHxv7T1X1+aq6rao+WlVPHo
1fVlV/W1W3ji4/N8/i14mVYACAwzNJYvpAkpddMHZjkm/v7u9I8sdJfnzsvru6+8To8ubZlLn+rAQDAByePUNwd38iyUMXjP1Wdz88uvnpJJfOobZBsRIMAHB4ZpGYfijJb4zdfkZV/X5V/U5VvWAGzz8IVoIBAA7PsWk2rqp3JHk4yS+Mhu5Pcry7v1xVz0/yK1X1bd39lW22PZXk1DSvv06sBLPsxnv2+PHjC64G2I1+hb0dODFV1Q8k+d4k/7y7O0m6+2vd/eXR9VuS3JXkOdtt392nu/tkd588aA3rxEowy268Zzc2NhZdDrAL/Qp7O1AIrqqXJfk3SV7V3X8zNr5RVUdH15+Z5PIkd8+i0HVnJRgA4PBMcoi065LcnOS5VXVvVb0pyc8meVKSGy84FNoLk9xWVbcm+eUkb+7uh7Z9Yh7DSjAALL+b73rHoktgRvbcJ7i7r9pm+P07PPb6JNdPW9QQWQmGxxrtZbWrqjqESrZ/XRMhPGpZ+3WrV7s7Vz77nVM/n75fL1N9MW6VbO7RoN0tQMKS6O781R8c3fNxT3zeuYUFYeC8VehXf7iynUGkvr0CcHK+QexKAIs36YSaJH/1B0cnWoEC5kO/ssrWPgRPEoC3CMKwWPuZULeYWGEx9Curbq1D8H4C8BZBGBbjIBPqFhMrHC79yjpY2xB8kAC8RRCGwzXNhLpl2u2ByehX1sXahmAAANiJEAwAwOAIwQAADI4QDADA4AjBAAAMjhAMAMDgCMEAAAzO2oZgB+IGAGAnaxuCjx45cuAg3N05emRtfzWwdKoqT3zeuame40knnOAGDoN+ZV2sddI7SBAWgGExpplYTahwuPQr62Dt095+grAADIt1kInVhAqLoV9ZdXsmvqq6pqoerKrbx8Z+sqruq6pbR5dXjN3341V1Z1XdUVX/ZF6F78dWEN7rIgDD4u1nYjWhwmLpV1bZsQke84EkP5vk2gvG39Pd7xofqKpvTfKGJN+W5GlJfruqntPd0+08NAMCLqyOqjJhworQr6yqPZNhd38iyUMTPt+rk/xSd3+tu/80yZ1JrpiiPgAAmLlplkffWlW3jXaXeMpo7JIkXxx7zL2jMQAAWBoHDcHvTfKsJCeS3J/k3ft9gqo6VVVnqurMAWsADtF4z549e3bR5QC70K+wtwOF4O5+oLvPdfdmkvfl0V0e7kvy9LGHXjoa2+45Tnf3ye4+eZAagMM13rMbGxuLLmdwJvlyL2zRr4ulX1fDJF+Me5yquri77x/dfE2SrSNHfCzJL1bVz+T8F+MuT/I/pq6SmdvcowEdLQOWw9ZkefVHvrTnY3/+tU9Lcv6LSsDh06+rZc8QXFXXJXlRkqdW1b1JfiLJi6rqRJJOck+Sq5Okuz9bVR9O8rkkDyd5yzIcGYJH7RV+t1RVNh06DhaquyeaTLdsPfbnX/s0EyscMv26eiY5OsRV3X1xd39dd1/a3e/v7u/v7r/X3d/R3a8aWxVOd/9Udz+ru5/b3b8x3/LZj+0C8F4f0VRVzm069A0ctv1OqOOu/siXfNwKh0i/rqYD7Q7B6rkwAG813HarvFuhd+sv060gbEUYDsduE+rp110y0WOv/siXrDDBIdCvq0uqGYDtAvDRI0d2DLVb943/ZWpFGA7HTpPk6ddd8rgJNTnfm6dfd8kj+xeOs8IE86VfV5sQvOZ2CsCTEIThcO02oe6lqnacWIHZ06+rTwgekIN8ye3CIAwcrkkm1C07Tax6GA6Hfl0tQvAaG18FnuYoD+PbWQ2G+dhuVWk/E+qW7SZWq0swW/p1PQjBTMRfpgDAOhGCB2AWx/p1ZAg4XAdZVdqy08eswHzo19Uk2QAAMDhCMAAAgyMEAwAwOEIwwBLyZVRYHfp1NQnBAzCLw5pdeNINYL6mOUzSbqdxBWZPv64mIXiN+csUVpsehtWhX1ePELzGZnWSi1mddAPY2U4Hzd/vxDqrg/gDO9Ov60GaWXPjDXmQIGw3CFis/U
ysPlaFxdKvq0UIXnNHjxw5cBC+MABbBYb52umg+VsT606T69Z9202oVpVgPvTr6tsz0VTVNVX1YFXdPjb2oaq6dXS5p6puHY1fVlV/O3bfz82zeCazXRDe7N4xDG92C8CwILtNrOOT6/hl674LmVBhvvTrajs2wWM+kORnk1y7NdDd/2zrelW9O8lfjj3+ru4+MasCmY2jR47k3OZmquqRsa0wvBcBGA7X1sS63UQ56cenJlQ4HPp1de2ZbLr7E0ke2u6+Op+oXp/kuhnXxRxsrQjvZ38lARgWY6cVpkmYUOFw6dfVNMlK8G5ekOSB7v6TsbFnVNXvJ/lKkv+ru//blK/BDG0F2kn2CxZ+YbGqKqdfd8nEX6AxmcLi6NfVM20IviqPXQW+P8nx7v5yVT0/ya9U1bd191cu3LCqTiU5NeXrc0ACLvs13rPHjx9fcDXDsjW5wqT06+Lo19Vx4CRUVceSvDbJh7bGuvtr3f3l0fVbktyV5Dnbbd/dp7v7ZHefPGgNwOEZ79mNjY1FlwPsQr/C3qZZDnxJks93971bA1W1UVVHR9efmeTyJHdPVyIAAMzWJIdIuy7JzUmeW1X3VtWbRne9IY//QtwLk9w2OmTaLyd5c3dv+6U6AABYlD33Ce7uq3YY/4Ftxq5Pcv30ZQEAwPz4dhQAAIMjBAMAMDhCMAAAgyMEAwAwOEIwAACDIwQDADA4QjAAAIMjBAMAMDhCMAAAgyMEAwAwOEIwAACDU9296BpSVWeT/HWSP190LTPw1Hgfy2QV3sff7e6NRRexH1X11SR3LLqOGViFfx+T8D4O10r1rH5dOt7H4dqxX48ddiXb6e6NqjrT3ScXXcu0vI/lsi7vYwndsQ6/13X59+F9sAf9ukS8j+VhdwgAAAZHCAYAYHCWKQSfXnQBM+J9LJd1eR/LZl1+r97HclmX97Fs1uX36n0sl5V/H0vxxTgAADhMy7QSDAAAh0IIBgBgcIRgAAAGRwgGAGBwhGAAAAZHCAYAYHCEYAAABkcIBgBgcIRgAAAGRwgGAGBwhGAAAAZHCAYAYHCEYAAABkcIBgBgcIRgAAAGRwgGAGBwhGAAAAZHCAYAYHCEYAAABkcIBgBgcIRgAAAGRwgGAGBwhGAAAAZHCAYAYHCEYAAABkcIBgBgcIRgAAAGRwgGAGBwhGAAAAZnbiG4ql5WVXdU1Z1V9bZ5vQ4AAOxXdffsn7TqaJI/TvLSJPcm+UySq7r7czN/MQAA2Kd5rQRfkeTO7r67u/9Xkl9K8uo5vRYAAOzLvELwJUm+OHb73tEYAAAs3LFFvXBVnUpyanTz+YuqA5bAn3f3xqKL2Mt4zz7hCU94/rd8y7csuCJYjFtuuWXpe1a/wnm79eu8QvB9SZ4+dvvS0dgjuvt0ktNJUlWz3zEZVscXFl3AJMZ79uTJk33mzJkFVwSLUVVL37P6Fc7brV/ntTvEZ5JcXlXPqKqvT/KGJB+b02sBAMC+zGUluLsfrqq3JvnNJEeTXNPdn53HawEAwH7NbZ/g7r4hyQ3zen4AADgoZ4wDAGBwhGAAVlZ3Zx4nfQJmb9n6VQiGMZvdObe5uegygAl0d6795JW59pNXLroUYA/L2K9CMAAAgyMEAwAwOEIwAACDIwQDADA4QjAAAIMjBAMAMDhCMAAAgyMEAwAwOEIwAACDc2zRBQDAuIOcVnU/21TVvp8f2N4q96sQzNrb3GeDVtXE23R3jh7xgQrMytapVfdrP9u88QU37/v5gcdb9X4Vgll7B/mL8yB/2QIAq0MIZu3tZ6V2s9vqLixQVU288jO+CmV1Fw7fqvfrgWf6qnp6VX28qj5XVZ+tqn89Gv/Jqrqvqm4dXV4xu3IBAGB606wEP5zkx7r796rqSUluqaobR/e9p7vfNX15AAAwewcOwd19f5L7R9e/WlV/lOSSWRUGAADzMpMdH6vqsiTfmeS/j4beWlW3VdU1VfWUWbwGAADMytQhuKqemOT6JD/c3V
9J8t4kz0pyIudXit+9w3anqupMVZ2ZtgZg/sZ79uzZs4suB9iFfoW9TRWCq+rrcj4A/0J3fyRJuvuB7j7X3ZtJ3pfkiu227e7T3X2yu09OUwNwOMZ7dmNjY9HlALvQr7C3aY4OUUnen+SPuvtnxsYvHnvYa5LcfvDyAABg9qY5OsQ/SvL9Sf6wqm4djb09yVVVdSJJJ7knydVTVQgAADM2zdEhPplkuxM633DwcgAAYP6cFgvGOF0yAAyD0ybDGKdLhtWxn1O2Aou1jP1qxgcAYHCEYAAABkcIBgBgcIRgAAAGRwgGAGBwhGAAAAZHCAYAYHCEYAAABkcIBgBgcIRgAAAGRwgGAGBwhGAAAAZHCAYAYHCOLboAAFZTd0/82KqaYyXAXvTr400dgqvqniRfTXIuycPdfbKqvjHJh5JcluSeJK/v7r+Y9rVYb5sTNmh35+gRH2LAInV3PvNd3zPx46/43Y/PsRpgN/p1e7NKEt/d3Se6++To9tuS3NTdlye5aXQbAACWwryW016d5IOj6x9M8n1zeh0AANi3WYTgTvJbVXVLVZ0ajV3U3fePrv9Zkotm8DoAADATs/hi3Hd1931V9c1Jbqyqz4/f2d1dVY/b2XMUmE9dOA4sp/GePX78+IKrAXajX2FvU68Ed/d9o58PJvlokiuSPFBVFyfJ6OeD22x3urtPju1HDCyx8Z7d2NhYdDnALvQr7G2qEFxVT6iqJ21dT/KPk9ye5GNJ3jh62BuT/Oo0rwMAALM07e4QFyX56Oh4cseS/GJ3/9eq+kySD1fVm5J8Icnrp3wdAACYmalCcHffneR524x/OcmLp3luAACYF2ccAABgcJw2maWxn1M6AgBMQwhmaTgVMqyOqhrMqVVh1enX7UkdAAAMjhAMAMDgCMHAXNjHG4BltjIh+Nzm5qJLACbU3Tly+pWCMKyI7tavDM7KhOCqEoRhxQjCsPy2/mjVrwzNyoTgRBCGVWRihdWhXxmSlQrBiSAMq8jECqtDvzIUKxeCE0EYVpGJFVaHfmUIVjIEJ4IwrCITK6wO/cq6W9kQnAjCsIpMrLA69CvrbKVDcCIIwyoyse5t65BVk1xgnvTr3vTrajq26AJmYSsIHz2y8pl+bjYnbLzu9nvkUBw5/cpsnvq1VNWiS1kqW5Pk1R/50sTb/Pxrn5YkfpfMjX7dnn5dbWsRghNBeCeTht8tVZXN0V+rfpfMm4n1sbp7X5Pplq1tfv61T/O7ZG7062Pp19V34JRTVc+tqlvHLl+pqh+uqp+sqvvGxl8xy4L3qMmuEWP2G4DH+V1yWHzU+uhHqQeZUMdd/ZEv+ciVudKv+nWdHHgluLvvSHIiSarqaJL7knw0yQ8meU93v2smFe6TFeHdw+9uzXbhX6RWhTksQ15hmsVkOs4qE/OmX/XrupjV7hAvTnJXd39hGf4jDjkI7xSAt8Lvbr+TrZXf7cLwUH+fHJ4hTqyTTKinX3fJgba9+iNf2nFbmJZ+3Z5+XS2zSjVvSHLd2O23VtVtVXVNVT1lRq+xLz7Of9TWSu5eIXbrMT6aYVF81Pqo06+7ZNdJsapy+nWXPPIlm+34XTJP+vVR+nU1TR2Cq+rrk7wqyX8ZDb03ybNyfleJ+5O8e4ftTlXVmao6M20Nu9Q2qCC83SrwQXZl2C4ID+13yeON9+zZs2fn9jpDmVh3Wxnaz4pQVe04sW7tc8jw6NfZ0q/raRYrwS9P8nvd/UCSdPcD3X2uuzeTvC/JFdtt1N2nu/tkd5+cQQ07Gkp4m1UA3iIIc6Hxnt3Y2Jjra637xDqrCXXLXhMrw6NfZ0e/rq9ZhOCrMrYrRFVdPHbfa5LcPoPXmMoQw9ssvsxm1wgWad0n1u1Ms0/gbhPr0H6PHD79uj/6dTlMlZKq6glJXprkI2PDP11Vf1hVtyX57iQ/Ms1rzMo6B+ELV4FneTSHC59nnX+PLJ91nFh3WlWaxZdidppYrS5xGPTr/ujXxZsqKXX3X3
f3N3X3X46NfX93/73u/o7uflV33z99mbMhwMHqWceJdR4ee7zRHl3gcOnXyTg+8HJYmzPGTWrdD/c1q1XgaU60AbO27odjmnZVqbvzV39w9JHb73pm8psP/Uhu/J8/kmQ9f2csL/26u1kfa5iDW88kuAcrwrsTgFlGVpi2d2EA3lKVvPTJ74kVYRbhyOlXLrqEpSQAL5eVCcFbHx3M6pJEEN6GAMwyM7E+1k4BeMujQRhYNAF4+azM7hDruvsCrKOqSl/964suA5iAfmWoJEsAAAZHCAYAYHCEYB7DF4/g8Ok7WB36dX0IwWtm2iNfOEscHL6DflmmqvLE5517zFh38lt/8SOPXL/xf/7o1PUBj5qmX3c6SxyLIQSvgVmH1q0gLAzD4Tlov+0UhAVgmN5OwXWaft0rCM/ibHRMRgheA/M4tfHRI0e2PSLHLE/JDEO026lSp5lYn3RiM0983rn8n3d/MTf+zx99TAA2qcJsTduvVoSXgzSzJi5sxlkEYccMhsM1zcTqGKQwH/P4w5XlIASvie325Z0mCG8XgK0Cw2zsthJ0kIl1twBsFRjmR7+uNolmjWwXUA8ShAVgmL9JgvBek+vWY0yoMF9b/bpdT+rX1bUyZ4xjMt2dqnrMWFVlc9R8uwVZuz/AdA7y0ehO22xNlLvtO2j3B5jebn1bVY/cf+Wz35lvfumJfPNLnve4eVa/riYheM0cPXIk5zY3H9egyaNheD+2mt8qMOyuu3Pls995oG03XvK8JHlM31700hNJDj5xWlWCve3Vt5+68+2Puf/BG29NknzzNj2b6NdVIwQvuXObm48E2/3+TB7foPthFwg4HGd/+w+SnA/D0/RsYjKFSU3yh+t29z94462PfOq63arwfujXxZoo4VTVNVX1YFXdPjb2jVV1Y1X9yejnU0bjVVX/uarurKrbqurvz6v4IdgKoQf5uXX9IDvtC8Bw+M7+9h9M9W1zEypMZppPbpLzvfrgjbfmwVHPHqRv9eviTZpyPpDkZReMvS3JTd19eZKbRreT5OVJLh9dTiV57/RlDtfWiu5Bf26F2a0mneSy0zGCgfnbmlz34/TrLjGhwgI8eOOtuf3ffjAPjj7NmYR+XR4T7Q7R3Z+oqssuGH51kheNrqByt8EAAB4FSURBVH8wyf+X5N+Oxq/t838WfbqqnlxVF3f3/bMoeGimWQne7iewGkySsDoevPHW/OrP/dNFl8E+TZOMLhoLtn+W5KLR9UuSfHHscfeOxjiAaVeCpz1hBgCwu5vveseiS+AAZrI8OFr13dcOMVV1qqrOVNWZWdSwrqwAsyzGe/bs2bOLLgfYhX6FvU2TkB6oqouTZPTzwdH4fUmePva4S0djj9Hdp7v7ZHefnKKGtWclmGUx3rMbGxuLLgfYhX6FvU0Tgj+W5I2j629M8qtj4/9idJSIf5jkL+0PfHBWggEAZm/SQ6Rdl+TmJM+tqnur6k1J/mOSl1bVnyR5yeh2ktyQ5O4kdyZ5X5J/NfOqB8RKMADA7E16dIirdrjrxds8tpO8ZZqieJSVYACA2ZOQlpyVYACA2ROCl5yVYACA2ZOQlpyVYACA2ROCl5yVYACA2ZOQlpyVYACA2ROCl5yVYACA2ZOQlpyVYACYrarKp+58+0ye6+a73jGT5+HwCcFLzkowAMzeLIKwALzaJjpZBgC7qyoTIqwYfTtslgkBABgcIRgAgMERggEAGBwhGACAwRGCAQAYHCEYAIDBEYIBABicPUNwVV1TVQ9W1e1jY/+pqj5fVbdV1Uer6smj8cuq6m+r6tbR5efmWTwAABzEJCvBH0jysgvGbkzy7d39HUn+OMmPj913V3efGF3ePJsyAQBgdvYMwd39iSQPXTD2W9398Ojmp5NcOofaAABgLmaxT/APJfmNsdvPqKrfr6rfqaoXzOD5AQBgpo5Ns3FVvSPJw0l+YTR0f5Lj3f3lqnp+kl+pqm/r7q9ss+2pJKemeX3g8Iz37PHjxx
dcDbAb/Qp7O/BKcFX9QJLvTfLPu7uTpLu/1t1fHl2/JcldSZ6z3fbdfbq7T3b3yYPWABye8Z7d2NhYdDnALvQr7O1AIbiqXpbk3yR5VXf/zdj4RlUdHV1/ZpLLk9w9i0IBAGBW9twdoqquS/KiJE+tqnuT/ETOHw3iG5LcWFVJ8unRkSBemOQ/VNX/TrKZ5M3d/dC2TwwAAAuyZwju7qu2GX7/Do+9Psn10xYFAADz5IxxAAAMjhAMAMDgCMEAAAyOEAwAwOAIwQAADI4QDADA4AjBAAAMjhAMAMDgCMEAAAyOEAwAwOAIwQAADI4QDADA4AjBAAAMjhAMAMDgCMEAAAyOEAwAwODsGYKr6pqqerCqbh8b+8mquq+qbh1dXjF2349X1Z1VdUdV/ZN5FQ4AAAc1yUrwB5K8bJvx93T3idHlhiSpqm9N8oYk3zba5v+tqqOzKhYAAGZhzxDc3Z9I8tCEz/fqJL/U3V/r7j9NcmeSK6aoDwAAZm6afYLfWlW3jXaXeMpo7JIkXxx7zL2jMQAAWBoHDcHvTfKsJCeS3J/k3ft9gqo6VVVnqurMAWsADtF4z549e3bR5QC70K+wtwOF4O5+oLvPdfdmkvfl0V0e7kvy9LGHXjoa2+45Tnf3ye4+eZAagMM13rMbGxuLLgfYhX6FvR0oBFfVxWM3X5Nk68gRH0vyhqr6hqp6RpLLk/yP6UoEAIDZOrbXA6rquiQvSvLUqro3yU8keVFVnUjSSe5JcnWSdPdnq+rDST6X5OEkb+nuc/MpHQAADmbPENzdV20z/P5dHv9TSX5qmqIAAGCenDEOAIDBEYIBABgcIRgAgMERggEAGBwhGACAwRGCAQAYHCEYAIDBEYIBABgcIRgAgMERggEAGBwhGACAwRGCAQAYHCEYAIDBEYIBABgcIRgAgMERggEAGJw9Q3BVXVNVD1bV7WNjH6qqW0eXe6rq1tH4ZVX1t2P3/dw8iwcAgIM4NsFjPpDkZ5NcuzXQ3f9s63pVvTvJX449/q7uPjGrAgEAYNb2DMHd/Ymqumy7+6qqkrw+yffMtiwAAJifafcJfkGSB7r7T8bGnlFVv19Vv1NVL5jy+QEAYOYm2R1iN1cluW7s9v1Jjnf3l6vq+Ul+paq+rbu/cuGGVXUqyakpXx84JOM9e/z48QVXA+xGv8LeDrwSXFXHkrw2yYe2xrr7a9395dH1W5LcleQ5223f3ae7+2R3nzxoDcDhGe/ZjY2NRZcD7EK/wt6m2R3iJUk+3933bg1U1UZVHR1df2aSy5PcPV2JAAAwW5McIu26JDcneW5V3VtVbxrd9YY8dleIJHlhkttGh0z75SRv7u6HZlkwAABMa5KjQ1y1w/gPbDN2fZLrpy8LAADmxxnjAAAYHCEYAIDBEYIBABgcIRgAgMERggEAGBwhGACAwRGCAQAYHCEYAIDBEYIBABgcIRgAgMERggEAGBwhGACAwanuXnQNqaqzSf46yZ8vupYZeGq8j2WyCu/j73b3xqKL2I+q+mqSOxZdxwyswr+PSXgfh2ulela/Lh3v43Dt2K/HDruS7XT3RlWd6e6Ti65lWt7HclmX97GE7liH3+u6/PvwPtiDfl0i3sfysDsEAACDIwQDADA4yxSCTy+6gBnxPpbLuryPZbMuv1fvY7msy/tYNuvye/U+lsvKv4+l+GIcAAAcpmVaCQYAgEMhBAMAMDhCMAAAgyMEAwAwOEIwAACDIwQDADA4QjAAAIMjBAMAMDhCMAAAgyMEAwAwOEIwAACDIwQDADA4QjAAAIMjBAMAMDhCMAAAgyMEAwAwOEIwAACDIwQDADA4QjAAAIMjBAMAMDhCMAAAgyMEAwAwOEIwAACDIwQDADA4QjAAAIMjBAMAMDhCMAAAgyMEAwAwOEIwAACDM7cQXFUvq6o7qurOqnrbvF4HAAD2q7p79k9adTTJHyd5aZJ7k3wmyVXd/bmZvxgAAOzTvFaCr0hyZ3ff3d3/K8kvJX
n1nF4LAAD25dicnveSJF8cu31vkn8w/oCqOpXk1Ojm8+dUB6yCP+/ujUUXsZfxnn3CE57w/G/5lm9ZcEWwGLfccsvS96x+hfN269d5heA9dffpJKeTpKpmv08GrI4vLLqASYz37MmTJ/vMmTMLrggWo6qWvmf1K5y3W7/Oa3eI+5I8fez2paMxAABYuHmF4M8kubyqnlFVX5/kDUk+NqfXAgCAfZnL7hDd/XBVvTXJbyY5muSa7v7sPF4LAAD2a277BHf3DUlumNfzAwDAQTljHAAAgyMEAwAwOEIwAACDIwQDADA4QjAAAIMjBAMAMDhCMAAAgyMEAwAwOEIwAACDIwQDADA4QjAAAIMjBAMAMDhCMECS7l50CQAcIiF4Ts5tbi66BGBC3Z0jp18pCMOK6G79ytSE4DmpKkEYVowgDMtv649W/cq0hOA5EoRh9ZhYYXXoV6Zx4BBcVU+vqo9X1eeq6rNV9a9H4z9ZVfdV1a2jyytmV+7qEYRh9ZhYYXXoVw5qmpXgh5P8WHd/a5J/mOQtVfWto/ve090nRpcbpq5yxQnCsHpMrLA69CsHceAQ3N33d/fvja5/NckfJblkVoWtG0EYVo+JFVaHfmW/ZrJPcFVdluQ7k/z30dBbq+q2qrqmqp6ywzanqupMVZ2ZRQ2rQBBmlY337NmzZxddzqExsbKK9CvsbeoQXFVPTHJ9kh/u7q8keW+SZyU5keT+JO/ebrvuPt3dJ7v75LQ1rBJBmFU13rMbGxuLLudQmVhZNfpVv7K3qUJwVX1dzgfgX+jujyRJdz/Q3ee6ezPJ+5JcMX2Z60UQhtVjYoXVoV+ZxDRHh6gk70/yR939M2PjF4897DVJbj94eetLEIbVY2KF1aFf2cs0K8H/KMn3J/meCw6H9tNV9YdVdVuS707yI7ModB0JwrB6TKywOvQruzl20A27+5NJapu7Bn9ItP3YCsJHjzhvCayKI6dfmc1Tv5bzH4gBy0y/shPJawlYEYbVY4UJVod+ZTtC8JIQhGH1mFhhdehXLiQELxFBGFaPiRVWh35lnBC8ZARhWD0mVlgdR06/ctElsCSE4Dnp7gNfkgjCsGJMrACr5cBHh2B3jvYAq6Oq0lf/+qLLACagX5kVSQ0AgMERggEAGBwhGACAwRGCAQAYHCEYAIDBEYIBABgcIRgAgMERggEAGBwhGACAwRGCAQAYnKlPm1xV9yT5apJzSR7u7pNV9Y1JPpTksiT3JHl9d//FtK8FAACzMKuV4O/u7hPdfXJ0+21Jburuy5PcNLoNAABLYV67Q7w6yQdH1z+Y5Pvm9DoAALBvswjBneS3quqWqjo1Gruou+8fXf+zJBdduFFVnaqqM1V1ZgY1AHM23rNnz55ddDnALvQr7G0WIfi7uvvvJ3l5krdU1QvH7+zuzvmgnAvGT3f3ybFdKIAlNt6zGxsbiy4H2IV+hb1NHYK7+77RzweTfDTJFUkeqKqLk2T088FpXwcAAGZlqhBcVU+oqidtXU/yj5PcnuRjSd44etgbk/zqNK8DAACzNO0h0i5K8tGq2nquX+zu/1pVn0ny4ap6U5IvJHn9lK8DAAAzM1UI7u67kzxvm/EvJ3nxNM8NAADz4oxxAAAMjhAMAMDgCMEAAAyOEAwAwOAIwQAADI4QDADA4AjBAAAMjhAMAMDgCMEAAAyOEAwAwOAIwQAADI4QDADA4AjBAAAMjhAMAMDgCMEAAAyOEAwAwOAcO+iGVfXcJB8aG3pmkn+X5MlJ/mWSs6Pxt3f3DQeuEAAAZuzAIbi770hyIkmq6miS+5J8NMkPJnlPd79rJhUCAMCMzWp3iBcnuau7vzCj5wMAgLmZVQh+Q5Lrxm6/tapuq6prquop221QVaeq6kxVnZlRDcAcjffs2bNn994AWBj9CnubOgRX1dcneVWS/zIaem+SZ+X8rhL3J3n3dtt19+nuPtndJ6etAZi/8Z7d2NhYdDnALvQr7G0WK8EvT/J73f1Akn
T3A919rrs3k7wvyRUzeA0AAJiZWYTgqzK2K0RVXTx232uS3D6D1wAAgJk58NEhkqSqnpDkpUmuHhv+6ao6kaST3HPBfQAAsHBTheDu/usk33TB2PdPVREAAMyZM8YBADA4QjAAAIMjBAMAMDhCMAAAgyMEAwAwOEIwAACDIwQDADA4QjAAAIMjBAMAMDhCMAAAgyMEA6yY7k53L7oMYAL6dXkJwWSzO+c2NxddBjCB7s61n7wy137yykWXAuxBvy43IRgAgMERggEAGJxjiy4AAID5mHZ/5KqaUSXLRwgGYGKz+ILPOk+qsEy6O1c++51TPcen7nz72vbsRCG4qq5J8r1JHuzubx+NfWOSDyW5LMk9SV7f3X9R539T/0+SVyT5myQ/0N2/N/vSWQfnNjdz9MiRqX8C8zeLCTVJbr7rHTOoBtjNrPr1yme/c217dtL08IEkL7tg7G1Jburuy5PcNLqdJC9PcvnocirJe6cvk3W1FWCn/QkAnDerALzuJkoQ3f2JJA9dMPzqJB8cXf9gku8bG7+2z/t0kidX1cWzKJb1s3Votml/AgDsxzTLaBd19/2j63+W5KLR9UuSfHHscfeOxh6jqk5V1ZmqOjNFDaw4K8GrY7xnz549u+hygF3oV9jbTBJEn/+mxL6+LdHdp7v7ZHefnEUNrCYrwatjvGc3NjYWXQ6wC/0Ke5vm6BAPVNXF3X3/aHeHB0fj9yV5+tjjLh2NweNYCYbzDnLUhf1ss67f7oZF0K/rYZoQ/LEkb0zyH0c/f3Vs/K1V9UtJ/kGSvxzbbYJDsrnPBq2qibfp7pmFT0eHgEdPrbpf+9nmjS+4ed/PDzyefl0fkx4i7bokL0ry1Kq6N8lP5Hz4/XBVvSnJF5K8fvTwG3L+8Gh35vwh0n5wxjUzgYP8xTmL43/ul5VgAFhu3b2Wq9MTheDuvmqHu168zWM7yVumKYrp7SccbnbPdHV3P6wEw/k/RCdd+RlfhVrV1aKD/MG9jhMwq2lo/ZqcP1bwp+58+8SPX5V+dcY4FmoWK8H73fVjUYEfSJLOZ77re/a91RW/+/E51ALsrfOyJ969r75dlX6VBFgoR4cAgGV1PgC//Il/uuhC5kIIZqHsEwwALIIEwUJZCQYAFkEIZqGsBAMAiyBBsFBWggGARRCCWSgrwQDAIkgQLJSVYBia1Th+KKyyqtrXcX2HynGCWahFHCcYOJiqys13vWOq5+jufOa7fntGFQE72QrCVz77ndM8y9oeHi0RgsliTpcMAMzXtH+4rvsfrUIw9quFFbKfU7YCi6Vfl5v0AwDA4FgJZuXZnQMA2C8hmJVndw5YHVWVK37344suA5jAuver9AAAwOAIwQAADM6eIbiqrqmqB6vq9rGx/1RVn6+q26rqo1X15NH4ZVX1t1V16+jyc/MsHgAADmKSleAPJHnZBWM3Jvn27v6OJH+c5MfH7ruru0+MLm+eTZkAADA7e4bg7v5EkocuGPut7n54dPPTSS6dQ20AADAXs9gn+IeS/MbY7WdU1e9X1e9U1Qt22qiqTlXVmao6M4MagDkb79mzZ88uuhxgF/oV9jZVCK6qdyR5OMkvjIbuT3K8u78zyY8m+cWq+jvbbdvdp7v7ZHefnKYG4HCM9+zGxsaiywF2oV9hbwc+TnBV/UCS703y4h6draC7v5bka6Prt1TVXUmek8RqL7Avk5wEpaoOoRJgL/qVVXSgEFxVL0vyb5L8H939N2PjG0ke6u5zVfXMJJcnuXsmlU5pc48G7W4nXYAlsDWZXv2RL+352J9/7dOSmFxhUfQrq2zPEFxV1yV5UZKnVtW9SX4i548G8Q1Jbhz9Y/706EgQL0zyH6rqfyfZTPLm7n5o2yc+JHuF3y1Vlc1uYRgWqLsnmky3bD3251/7NBMrHDL9yqrbMwR391XbDL9/h8den+T6aYuale0C8HYf2Yw3Y1Xl3OamIAyHbL8T6rirP/IlEyscIv3KOjjwPsHL7sIAvB
V+twu35zY3kzwahgVhOFy7TainX3fJRI81scLh0K+si7VMedsF4KNHjuwYarfuG18l3grCwHztNEmeft0lj5tQk/O9efp1lzyyf+G4qz/ypYm+oAMcjH5lnaxdCN4pAE9CEIbDtduEupeq2nFiBWZPv7Ju1i4EjzvIl9wuDMLA4ZpkQt2y08Sqh+Fw6FdW2VqF4PFV4GmO8jC+ndVgmI/tVpX2M6Fu2W5itboEs6VfWUdrFYJnyV+mAADray1D8CyO9evIEHC4DrKqtGWnj1mB+dCvrANJDwCAwRGCAQAYHCEYAIDBEYKBpeDLqLA69CvrYC1D8CwOa3bhSTeA+ZrmMEm7ncZ11XT3nhdYNP16nn5dbccWXcAsdbfzkF9grzA/iyNpwKwMvYe7O3/1B0f3fNwTn3du0L8nloN+1a+rbq3Sz6xOcjGrk24s2iSr2U4GwqLsdND8/a6czOog/os26YSaJH/1B0etMHGo9Otj6df1sJrpbhfj/9AOEvDWZTeI/bwPQZhlsp+JdV0+Vt3PhLrFxMoy0K+T0a/Lae1C8NEjRw4chC8Mjqu6CnyQIC8Iswg7HTR/a2LdadLYum+7CXXVVpUOMqFuMbFymPSrfl03eya8qrqmqh6sqtvHxn6yqu6rqltHl1eM3ffjVXVnVd1RVf9kXoXvZrsgvNm9Y8jb7B50AN4iCLMIu02s45Pr+GXrvgsNaULdMu32sB/6Vb+uk0m+GPeBJD+b5NoLxt/T3e8aH6iqb03yhiTfluRpSX67qp7T3edmUOu+HD1yJOc2Nx+zM/pWGN7LqgZgWFVbE+t2E+WkH5+u2oQKq0q/si72THrd/YkkD034fK9O8kvd/bXu/tMkdya5Yor6prK1Iryf/ZUEYFiMnVaYJmFChcOlX1kH0xwi7a1V9S+SnEnyY939F0kuSfLpscfcOxp7nKo6leTUFK8/ka1AO8nH/MIv7Gy8Z48fPz6v18jp110y8RdoTKawPf0KeztoCH5vkv87SY9+vjvJD+3nCbr7dJLTSVJVc99TXMCF6Yz37MmTJ+fas1uTK3Aw+hX2dqBk2N0PdPe57t5M8r48usvDfUmePvbQS0djAACwNA4Ugqvq4rGbr0mydeSIjyV5Q1V9Q1U9I8nlSf7HdCUCAMBs7bk7RFVdl+RFSZ5aVfcm+YkkL6qqEzm/O8Q9Sa5Oku7+bFV9OMnnkjyc5C2LODIEAADsZs8Q3N1XbTP8/l0e/1NJfmqaoji4oZ/LHQBgEr4ttmYuPFHIfjg8HByuqsoTnzfdh2VPOuEEN3AY9Ov6kXjW0EGCsAAMizHNxGpChcOlX9eL1LOm9hOEBWBYrINMrCZUWAz9uj4knzU2fsa83S4CMCzefiZWEyosln5dD9OcMY4VIODC6qgqEyasCP26+iQkAAAGRwgGAGBwhGAAAAZHCAYAYHCEYAAABkcIBgBgcIRgAAAGRwgGAGBwhGAAAAZHCAYAYHCEYAAABmfPEFxV11TVg1V1+9jYh6rq1tHlnqq6dTR+WVX97dh9PzfP4gEA4CCOTfCYDyT52STXbg109z/bul5V707yl2OPv6u7T8yqQAAAmLU9Q3B3f6KqLtvuvqqqJK9P8j2zLQsAAOZn2n2CX5Dkge7+k7GxZ1TV71fV71TVC3basKpOVdWZqjozZQ3AIRjv2bNnzy66HGAX+hX2Nm0IvirJdWO3709yvLu/M8mPJvnFqvo7223Y3ae7+2R3n5yyBuAQjPfsxsbGossBdqFfYW8HDsFVdSzJa5N8aGusu7/W3V8eXb8lyV1JnjNtkQAAMEvTrAS/JMnnu/verYGq2qiqo6Prz0xyeZK7pysRAABma5JDpF2X5OYkz62qe6vqTaO73pDH7gqRJC9MctvokGm/nOTN3f3QLAsGAIBpTXJ0iKt2GP+BbcauT3L99GUBAMD8OGMcAACDIwQDADA4QjAAAIMjBAMAMDhCMAAAgyMEAwAwOEIwAACDIwQDADA4Qj
AAAIMjBAMAMDhCMAAAg1PdvegaUlVnk/x1kj9fdC0z8NR4H8tkFd7H3+3ujUUXsR9V9dUkdyy6jhlYhX8fk/A+DtdK9ax+XTrex+HasV+PHXYl2+nujao6090nF13LtLyP5bIu72MJ3bEOv9d1+ffhfbAH/bpEvI/lYXcIAAAGRwgGAGBwlikEn150ATPifSyXdXkfy2Zdfq/ex3JZl/exbP7/9u7eRaorDuP49yEQixCIWohEC5VttFkkiEUQ0qhrs6ZbGy0EG/0DDDb5B0JAiBYBWZMidqKFhS9NKlEDG1+K1fUFdFG3CKTUkPxS3LPkZrNXnQ3ec87M84HLzNyZhefcOw8c7svssGxXj6Ms1Y+jiBvjzMzMzMz6VNKRYDMzMzOzXmSfBEvaK2lW0pyk47nzDELSE0l3JM1IupXWrZF0RdKD9Lg6d86lJJ2RtCDpbmvdsrnVOJn2z21J2/Ml/7eOcXwtaT7tkxlJ+1rvfZXGMStpT57U9XNn++fOurMr5b72z32tp69ZJ8GSPgC+AyaArcABSVtzZlqBLyJivPUzIceBaxExBlxLr0szDexdsq4r9wQwlpYjwOmeMr6Laf47DoBv0z4Zj4hLAOl7NQVsS39zKn3/bADubDbTuLPu7IDc12ymcV+r6GvuI8E7gLmIeBQRr4FzwGTmTP/XJHA2PT8L7M+YZVkR8TPw25LVXbkngR+icR34RNL6fpK+Wcc4ukwC5yLiVUQ8BuZovn82GHc2A3fWnV0h9zUD97WevuaeBH8KPG29fpbW1SKAy5J+kXQkrVsXEc/T8xfAujzRBtaVu8Z9dCydVjrTOlVW4zhKVPt2dGfL5M6+H7VvQ/e1TEPT19yT4Np9HhHbaU5nHJW0q/1mND+9Ud3Pb9SaOzkNbAHGgefAN3njWGHc2fK4s9bFfS3PUPU19yR4HtjYer0hratCRMynxwXgPM2h/5eLpzLS40K+hAPpyl3VPoqIlxHxZ0T8BXzPP6djqhpHwareju5sedzZ96rqbei+lmfY+pp7EnwTGJO0SdKHNBdVX8yc6Z1I+kjSx4vPgd3AXZr8h9LHDgEX8iQcWFfui8DBdAfrTuD31imd4iy5lupLmn0CzTimJK2StInmJoQbfecbAu5sOdxZexv3tRzua4kiIusC7APuAw+BE7nzDJB7M/BrWu4tZgfW0tz5+QC4CqzJnXWZ7D/RnMb4g+a6ncNduQHR3F38ELgDfJY7/1vG8WPKeZumlOtbnz+RxjELTOTOX+vizmbJ7s66syvd5u5r/9nd10r66v8YZ2ZmZmYjJ/flEGZmZmZmvfMk2MzMzMxGjifBZmZmZjZyPAk2MzMzs5HjSbCZmZmZjRxPgs3MzMxs5HgSbGZmZmYjx5NgMzMzMxs5fwPhfH5V5PwNCQAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "<Figure size 864x864 with 9 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Wsfxcw0-DZdn",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 15,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JPG4VNTnFr3p",
        "colab_type": "text"
      },
      "source": [
        "## Next steps\n",
        "\n",
        "Try tweaking the hyperparameters for better accuracy, e.g.:\n",
        "\n",
        "- learning rates and schedules\n",
        "- loss weights\n",
        "- unfreezing layers\n",
        "- batch size\n",
        "- etc."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7VHV2fS4GRd-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 15,
      "outputs": []
    }
  ]
}